ClassifierTwoStage

class dpeeg.trainer.ClassifierTwoStage(model: Module, loss_fn: str | Type[Module] = 'NLLLoss', loss_fn_args: dict | None = None, optimizer: str | Type[Optimizer] = 'Adam', optimizer_args: dict | None = None, lr: float = 0.001, lr_sch: str | Type[LRScheduler] | None = None, lr_sch_args: dict | None = None, grad_acc: int = 1, batch_size: int = 32, max_epochs_s1: int = 1500, max_epochs_s2: int = 600, no_increase_epochs: int = 200, second_stage: bool = True, var_check: Literal['valid_inacc', 'valid_loss'] = 'valid_inacc', load_best_state: bool = True, nGPU: int = 0, seed: int = 42, keep_data_gpu: bool = True, depth: int | None = None, data_size: tuple | list | None = None, verbose: int | str = 'INFO')

Two-stage classifier training.

The model is trained in two stages. In the first stage, it is trained on the training set only, with early stopping: the validation accuracy or loss (as selected by var_check) is monitored, and training stops once it has not improved for no_increase_epochs consecutive epochs (200 by default) or after max_epochs_s1 epochs. The network parameters with the best validation accuracy (or loss) are then restored if load_best_state is True. In the second stage, the model is trained on the complete training data (training plus validation sets). Second-stage training stops when the validation loss falls below the first-stage training loss, or after max_epochs_s2 epochs.
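The control flow described above can be sketched in plain Python. This is a minimal sketch under stated assumptions, not dpeeg's implementation: train_epoch, evaluate, get_state and set_state are hypothetical callables standing in for the real training, validation and checkpointing code; 'valid_inacc' is assumed to mean 1 minus the validation accuracy (lower is better); and the second-stage target is assumed to be the training loss recorded at the restored first-stage state (or at the last first-stage epoch when load_best_state is False).

```python
from typing import Callable, Dict, Tuple


def two_stage_fit(
    train_epoch: Callable[[str], float],          # hypothetical: trains one epoch, returns training loss
    evaluate: Callable[[], Tuple[float, float]],  # hypothetical: returns (valid_loss, valid_acc)
    get_state: Callable[[], object],              # hypothetical: snapshot of the model parameters
    set_state: Callable[[object], None],          # hypothetical: restores a snapshot
    max_epochs_s1: int = 1500,
    max_epochs_s2: int = 600,
    no_increase_epochs: int = 200,
    var_check: str = "valid_inacc",
    load_best_state: bool = True,
    second_stage: bool = True,
) -> Dict[str, int]:
    if var_check not in ("valid_inacc", "valid_loss"):
        raise ValueError(f"unknown var_check: {var_check!r}")

    # Stage 1: train on the training set only, with early stopping.
    best_value = float("inf")  # assumption: lower is better for both metrics
    best_state, best_train_loss = get_state(), float("inf")
    last_train_loss = float("inf")
    stale, epochs_s1 = 0, 0
    for _ in range(max_epochs_s1):
        last_train_loss = train_epoch("train")
        valid_loss, valid_acc = evaluate()
        epochs_s1 += 1
        value = 1.0 - valid_acc if var_check == "valid_inacc" else valid_loss
        if value < best_value:
            best_value, best_state, best_train_loss = value, get_state(), last_train_loss
            stale = 0
        else:
            stale += 1
            if stale >= no_increase_epochs:
                break  # no improvement for `no_increase_epochs` consecutive epochs

    # Restore the best state (or keep the last one) and fix the stage-2 target loss.
    if load_best_state:
        set_state(best_state)
        target = best_train_loss
    else:
        target = last_train_loss

    # Stage 2: train on train + valid until the validation loss drops below the target.
    epochs_s2 = 0
    if second_stage:
        for _ in range(max_epochs_s2):
            train_epoch("train+valid")
            epochs_s2 += 1
            valid_loss, _ = evaluate()
            if valid_loss < target:
                break
    return {"epochs_s1": epochs_s1, "epochs_s2": epochs_s2}
```

Passing the training and checkpointing steps in as callables keeps the sketch independent of any framework, so the stopping logic can be exercised with a toy "model" whose metrics are scripted per epoch.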

Parameters:
  • model (Module) – A subclass of torch.nn.Module that defines the forward method. The first value returned by forward is taken as the prediction.

  • loss_fn (str, Type[Module]) – Name of a loss function in torch.nn to use for training, or a custom loss class. Note: a custom loss_fn must be a class (not an instance); it is instantiated as loss_fn(**loss_fn_args).

  • loss_fn_args (dict, optional) – Additional arguments to be passed to the loss function.

  • optimizer (str, Type[Optimizer]) – Name of an optimizer in torch.optim to use for training, or a custom optimizer class. Note: a custom optimizer must be a class (not an instance); it is instantiated as optimizer(model.parameters(), lr=lr, **optimizer_args).

  • optimizer_args (dict, optional) – Additional arguments to be passed to the optimizer.

  • lr (float) – Learning rate.

  • lr_sch (str, Type[LRScheduler], optional) – Name of a learning-rate scheduler in torch.optim.lr_scheduler to use for training, or a custom scheduler class. Note: a custom lr_sch must be a class (not an instance); it is instantiated as lr_sch(optimizer, **lr_sch_args).

  • lr_sch_args (dict, optional) – Additional arguments to be passed to the learning-rate scheduler.

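  • grad_acc (int) – Gradient accumulation: the number of mini-batches whose gradients are accumulated before each optimizer step (1 disables accumulation).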

  • batch_size (int) – Mini-batch size.

  • max_epochs_s1 (int) – Maximum number of epochs in the first stage of training.

  • max_epochs_s2 (int) – Maximum number of epochs in the second stage of training.

  • no_increase_epochs (int) – Number of consecutive first-stage epochs without improvement in the validation accuracy or loss after which the first stage stops early.

  • second_stage (bool) – If True, both training stages are performed; otherwise, only the first stage is run.

  • load_best_state (bool) – If True, the second stage retrains from the best first-stage state. Otherwise, it retrains from the last first-stage state.

  • var_check (str) – Validation metric ('valid_inacc' or 'valid_loss') used to select the best first-stage state, which initializes the model parameters for the second stage of training.

  • cls_name (list of str) – Names of the dataset labels.

  • nGPU (int) – Index of the GPU used for training. If no GPU is available, the CPU is used.

  • seed (int) – Random seed, for reproducibility.

  • keep_data_gpu (bool) – If True, keep the dataset on the GPU to avoid the overhead of transferring data to the device. Set it to False if the dataset does not fit in GPU memory.

  • data_size (tuple, list, optional) – Input shape used to print the structure of the network model. The structure is printed only if data_size is given.

  • depth (int, optional) – Depth of nested layers to display in the printed model structure.

  • verbose (int, str) – Console log level. Default is 'INFO'. Mainly useful for debugging.
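Gradient accumulation (grad_acc) can be illustrated generically. The sketch below is plain Python, not dpeeg's code: it fits a one-parameter linear model y = w·x with SGD and takes an optimizer step only once every grad_acc mini-batches. Each mini-batch gradient is divided by grad_acc, so with equally sized mini-batches one accumulated step matches one step on the merged batch. The function names are hypothetical, and a trailing group of fewer than grad_acc mini-batches is simply discarded here; real trainers may handle it differently.

```python
from typing import List, Sequence, Tuple

Batch = List[Tuple[float, float]]  # (x, y) pairs


def mse_grad(w: float, batch: Sequence[Tuple[float, float]]) -> float:
    """Derivative w.r.t. w of mean((w * x - y) ** 2) over the batch."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def train_with_grad_acc(
    w: float, batches: Sequence[Batch], lr: float, grad_acc: int
) -> Tuple[float, int]:
    """Plain SGD that steps once every `grad_acc` mini-batches.

    Returns the final weight and the number of optimizer steps taken.
    """
    grad, steps = 0.0, 0
    for i, batch in enumerate(batches, start=1):
        # Scale each mini-batch gradient so the accumulated sum is an average.
        grad += mse_grad(w, batch) / grad_acc
        if i % grad_acc == 0:
            w -= lr * grad  # analogous to optimizer.step()
            grad = 0.0      # analogous to optimizer.zero_grad()
            steps += 1
    return w, steps
```

For example, two mini-batches of two samples with grad_acc=2 yield the same single update as one four-sample batch with grad_acc=1, while only two samples need to be held per forward pass.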

fit(trainset: EEGData, validset: EEGData, testset: EEGData, log_dir: str)

Fit the model.

Parameters:
  • trainset (EEGData) – Dataset used for training.

  • validset (EEGData) – Dataset used for validation.

  • testset (EEGData) – Dataset used to evaluate the model.

  • log_dir (str) – The path to save the training log.

Returns:

Returns the training set, validation set, and test set results (including true labels, predicted labels and accuracy).

Return type:

dict
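The accuracy reported for each split follows from its true and predicted labels. The exact keys of the returned dict are not documented here, so the sketch below assumes a hypothetical layout with one entry per split holding "true" and "pred" label lists; these key names are illustrative assumptions, not dpeeg's actual keys.

```python
from typing import Dict, Sequence


def accuracy(true_labels: Sequence[int], pred_labels: Sequence[int]) -> float:
    """Fraction of samples whose predicted label equals the true label."""
    if len(true_labels) != len(pred_labels):
        raise ValueError("label sequences must have the same length")
    if len(true_labels) == 0:
        raise ValueError("label sequences must not be empty")
    return sum(t == p for t, p in zip(true_labels, pred_labels)) / len(true_labels)


def summarize(results: Dict[str, Dict[str, Sequence[int]]]) -> Dict[str, float]:
    """Recompute per-split accuracy from a results dict.

    Hypothetical layout: {"train": {"true": [...], "pred": [...]}, "valid": ..., "test": ...}
    """
    return {split: accuracy(r["true"], r["pred"]) for split, r in results.items()}
```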