Classifier
- class dpeeg.trainer.Classifier(model: Module, loss_fn: str | Type[Module] = 'NLLLoss', loss_fn_args: dict | None = None, optimizer: str | Type[Optimizer] = 'Adam', optimizer_args: dict | None = None, lr: float = 0.001, lr_sch: str | Type[LRScheduler] | None = None, lr_sch_args: dict | None = None, grad_acc: int = 1, batch_size: int = 32, max_epochs: int = 1000, no_increase_epochs: int = 50, var_check: Literal['train_loss', 'train_inacc'] = 'train_loss', load_best_state: bool = True, nGPU: int = 0, seed: int = 42, keep_data_gpu: bool = True, depth: int | None = None, data_size: tuple | list | None = None, verbose: int | str = 'INFO')
Classifier training.
In the various training procedures, early stopping can be applied using only the training set (a validation set is not required) to select the model.
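The selection rule can be illustrated with a small, framework-free sketch. It is not dpeeg's implementation: the helper select_best_epoch is hypothetical, and it assumes that the monitored value (training loss or training inaccuracy, as chosen by var_check) is lower-is-better and that no_increase_epochs acts as the patience counter.

```python
from typing import Optional, Sequence, Tuple


def select_best_epoch(
    metric_per_epoch: Sequence[float],
    no_increase_epochs: int,
    max_epochs: Optional[int] = None,
) -> Tuple[int, int]:
    """Return (best_epoch, stopped_epoch) for a lower-is-better training metric.

    Hypothetical helper: early stopping on the training set only, with
    `no_increase_epochs` used as the patience.
    """
    best_value = float("inf")
    best_epoch = -1
    epochs_without_improvement = 0
    limit = len(metric_per_epoch)
    if max_epochs is not None:
        limit = min(max_epochs, limit)

    for epoch in range(limit):
        value = metric_per_epoch[epoch]
        if value < best_value:
            # New best: this is the epoch whose weights would be kept.
            best_value = value
            best_epoch = epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= no_increase_epochs:
                # Patience exhausted: stop training at this epoch.
                return best_epoch, epoch
    return best_epoch, limit - 1


# Toy training-loss curve, for illustration only.
train_loss = [0.90, 0.70, 0.65, 0.66, 0.67, 0.64, 0.68, 0.69, 0.70]
best, stopped = select_best_epoch(train_loss, no_increase_epochs=3)
print(best, stopped)  # 5 8
```

With these toy values the best epoch is 5 and training stops at epoch 8, after three consecutive epochs without improvement. Restoring the weights saved at the best epoch corresponds to the behavior described for load_best_state.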
- Parameters:
model (Module) – Model to train. It must subclass torch.nn.Module and define the forward method. The first value returned by the forward pass is taken as the prediction.
loss_fn (str, Type[Module]) – Name of a loss class in torch.nn to use for training. Alternatively, pass a custom loss class derived from Module. Note: a custom loss_fn must be the class itself (not an instance); it is instantiated as loss_fn(**loss_fn_args).
loss_fn_args (dict, optional) – Additional arguments to be passed to the loss function.
optimizer (str, Type[Optimizer]) – Name of an optimizer class in torch.optim to use for training. Alternatively, pass a custom optimizer class derived from Optimizer. Note: a custom optimizer must be the class itself (not an instance); it is instantiated with the arguments (model, lr=lr, **optimizer_args).
optimizer_args (dict, optional) – Additional arguments to be passed to the optimizer.
lr (float) – Learning rate.
lr_sch (str, Type[LRScheduler], optional) – Name of a learning-rate scheduler class in torch.optim.lr_scheduler to use for training. Alternatively, pass a custom scheduler class derived from LRScheduler. Note: a custom lr_sch must be the class itself (not an instance); it is instantiated as lr_sch(optimizer, **lr_sch_args).
lr_sch_args (dict, optional) – Additional arguments to be passed to the learning-rate scheduler.
grad_acc (int) – Number of gradient accumulation steps, i.e. how many mini-batches have their gradients accumulated before each optimizer update. A value of 1 disables accumulation.
batch_size (int) – Mini-batch size.
max_epochs (int) – Maximum number of epochs in training.
no_increase_epochs (int) – Early-stopping patience: the maximum number of consecutive epochs in which the monitored training-set metric (loss or accuracy) shows no relative improvement.
var_check (str) – Training-set metric monitored to select the best model, which is then evaluated on the test set: 'train_loss' (training loss) or 'train_inacc' (training inaccuracy).
load_best_state (bool) – If True, the best model parameters will be used for evaluation.
nGPU (int) – ID of the GPU used for training. If no GPU is available, the CPU is used instead.
seed (int) – Random seed, set for reproducibility.
keep_data_gpu (bool) – If True, keep the dataset on the GPU to avoid the cost of repeatedly transferring data from the host. Adjust according to the available GPU memory.
data_size (tuple, list, optional) – Input shape. If given, a summary of the network structure is displayed for an input of this shape.
depth (int, optional) – Depth of nested layers to display in the network structure summary.
verbose (int, str) – Console log level (default 'INFO'). Mainly used for debugging.
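The string-or-class convention documented for loss_fn and lr_sch (and their *_args dictionaries) can be sketched in plain PyTorch. The helpers build_loss and build_scheduler and the class SmoothedNLLLoss are hypothetical illustrations, not part of dpeeg; they only mirror the documented instantiation rules loss_fn(**loss_fn_args) and lr_sch(optimizer, **lr_sch_args).

```python
import torch
import torch.nn.functional as F
from torch import nn


class SmoothedNLLLoss(nn.Module):
    """Hypothetical custom loss on log-probabilities with label smoothing."""

    def __init__(self, smoothing: float = 0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, log_probs: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        nll = F.nll_loss(log_probs, target)
        # Cross-entropy against a uniform distribution over the classes.
        uniform = -log_probs.mean(dim=-1).mean()
        return (1.0 - self.smoothing) * nll + self.smoothing * uniform


def build_loss(loss_fn, loss_fn_args=None):
    """Resolve a torch.nn class name or take a custom class, then instantiate it."""
    loss_cls = getattr(nn, loss_fn) if isinstance(loss_fn, str) else loss_fn
    return loss_cls(**(loss_fn_args or {}))


def build_scheduler(optimizer, lr_sch, lr_sch_args=None):
    """Same pattern for torch.optim.lr_scheduler names or custom scheduler classes."""
    if isinstance(lr_sch, str):
        sch_cls = getattr(torch.optim.lr_scheduler, lr_sch)
    else:
        sch_cls = lr_sch
    return sch_cls(optimizer, **(lr_sch_args or {}))


torch.manual_seed(42)
log_probs = torch.log_softmax(torch.randn(4, 3), dim=-1)
target = torch.tensor([0, 1, 2, 0])

builtin = build_loss("NLLLoss")                           # same as nn.NLLLoss()
custom = build_loss(SmoothedNLLLoss, {"smoothing": 0.2})  # pass the class, not an instance
print(builtin(log_probs, target).item(), custom(log_probs, target).item())

model = nn.Linear(8, 3)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = build_scheduler(optimizer, "StepLR", {"step_size": 10, "gamma": 0.5})
```

Under the documented convention, such a custom loss would be supplied as loss_fn=SmoothedNLLLoss together with loss_fn_args={'smoothing': 0.2}. Like the default NLLLoss, it expects the model to output log-probabilities.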
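Gradient accumulation can be sketched as follows, assuming grad_acc counts the mini-batches whose gradients are summed before each optimizer step (so the effective batch size is roughly batch_size * grad_acc). The loop train_one_epoch is a hypothetical illustration in plain PyTorch, not dpeeg's training code.

```python
import torch
from torch import nn


def train_one_epoch(model, loss_fn, optimizer, batches, grad_acc=1):
    """Hypothetical loop: step the optimizer once every `grad_acc` mini-batches."""
    model.train()
    optimizer.zero_grad()
    n_steps = 0
    for i, (xb, yb) in enumerate(batches, start=1):
        # Scale so that, over a full group, the summed gradients equal their average.
        loss = loss_fn(model(xb), yb) / grad_acc
        loss.backward()  # gradients add up in .grad until zero_grad() is called
        if i % grad_acc == 0 or i == len(batches):
            optimizer.step()
            optimizer.zero_grad()
            n_steps += 1
    return n_steps


torch.manual_seed(42)
x, y = torch.randn(32, 8), torch.randint(0, 3, (32,))
batches = list(zip(x.split(8), y.split(8)))  # 4 mini-batches of 8 samples

model = nn.Linear(8, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
n_steps = train_one_epoch(model, nn.CrossEntropyLoss(), optimizer, batches, grad_acc=4)
print(n_steps)  # 1 optimizer step for 4 mini-batches
```

With equal-sized mini-batches and a mean-reduced loss, dividing each loss by grad_acc makes the accumulated gradient of a full group equal to the gradient of the mean loss over all of its samples. This is why accumulation is commonly used to emulate a larger batch when GPU memory is limited.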