# Authors: SheepTAO <sheeptao@outlook.com>
# License: MIT
# Copyright the dpeeg contributors.
import os
from abc import abstractmethod, ABC
from copy import deepcopy
import torch
import torch.nn as nn
from torch import optim
from torch import Tensor
from torch.nn import Module
from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.backends import cudnn
from torchinfo import summary
from typing import Literal, Type
from torch.utils.tensorboard.writer import SummaryWriter
from torch.utils.data import TensorDataset, DataLoader, ConcatDataset
from torchmetrics.functional.classification.accuracy import multiclass_accuracy
from torchmetrics.aggregation import MeanMetric, CatMetric
from .base import Trainer
from ..tools import Logger, Timer
from ..utils import DPEEG_SEED
from ..transforms.functions import to_tensor
from .stopcriteria import ComposeStopCriteria
from .utils import get_device, get_device_name, model_depth
from ..utils import _set_torch_seed, mapping_to_str
from ..datasets.base import EEGData
# Prefer reproducibility over raw speed: turn off cuDNN's benchmark-based
# algorithm selection and ask cuDNN for deterministic algorithms only.
cudnn.benchmark, cudnn.deterministic = False, True
class BaseClassifier(Trainer, ABC):
    """Classification model trainer.

    Generate a trainer to test the performance of the same network on
    different datasets. This base class stores the training configuration,
    rebuilds the loss function, optimizer and learning-rate scheduler at the
    start of every run (see ``_reset_fitter``), wraps datasets into
    ``DataLoader`` objects, and provides the per-epoch training loop and the
    prediction loop. Subclasses implement the ``fit`` procedure.
    """

    def __init__(
        self,
        model: Module,
        loss_fn: str | Type[Module] = "NLLLoss",
        loss_fn_args: dict | None = None,
        optimizer: str | Type[Optimizer] = "Adam",
        optimizer_args: dict | None = None,
        lr: float = 1e-3,
        lr_sch: str | Type[LRScheduler] | None = None,
        lr_sch_args: dict | None = None,
        grad_acc: int = 1,
        batch_size: int = 32,
        nGPU: int = 0,
        seed: int = DPEEG_SEED,
        keep_data_gpu: bool = True,
        depth: int | None = None,
        data_size: tuple | list | None = None,
        verbose: int | str = "INFO",
    ) -> None:
        """Set up the model, device, random seed and training configuration.

        Parameters
        ----------
        model : Module
            Network to train. It is moved to the selected device.
        loss_fn, loss_fn_args :
            Loss class (or its name in `torch.nn`) and its keyword arguments.
        optimizer, optimizer_args :
            Optimizer class (or its name in `torch.optim`) and its keyword
            arguments. ``lr`` is added to a private copy of
            ``optimizer_args``; the caller's dict is never modified.
        lr : float
            Learning rate.
        lr_sch, lr_sch_args :
            Optional learning-rate scheduler class (or its name in
            `torch.optim.lr_scheduler`) and its keyword arguments.
        grad_acc : int
            Number of mini-batches whose gradients are accumulated before
            each optimizer step.
        batch_size : int
            Mini-batch size.
        nGPU : int
            GPU id passed to ``get_device``.
        seed : int
            Torch random seed.
        keep_data_gpu : bool
            If True, datasets are moved to the device when they are wrapped
            by ``data_loader``.
        depth : int, optional
            Depth passed to ``torchinfo.summary``; defaults to
            ``model_depth(model)``.
        data_size : tuple, list, optional
            Input size passed to ``torchinfo.summary``.
        verbose : int, str
            Console log level.
        """
        super().__init__(model)
        self.model = model
        self.loger = Logger("dpeeg_train", clevel=verbose)
        self.timer = Timer()

        # init trainer
        self.device = get_device(nGPU)
        self.loger.info(
            f"Model will be trained on {self.device} "
            f"({get_device_name(self.device)})"
        )
        self.model.to(self.device)
        _set_torch_seed(seed)
        self.loger.info(f"Set torch random seed = {seed}")

        # summarize model structure
        self.model_arch = str(model) + "\n"
        depth = model_depth(self.model) if depth is None else depth
        self.model_arch += str(summary(model, data_size, depth=depth))
        self.loger.info(self.model_arch)

        # the type of optimizer, etc. selected
        self.loss_fn_type = loss_fn
        self.optimizer_type = optimizer
        self.lr_sch_type = lr_sch

        # Keep private copies of the argument dicts so that injecting "lr"
        # below never mutates dictionaries owned by the caller.
        self.loss_fn_args = dict(loss_fn_args) if loss_fn_args else {}
        self.optimizer_args = dict(optimizer_args) if optimizer_args else {}
        self.lr_sch_args = dict(lr_sch_args) if lr_sch_args else {}

        # --- others
        self.optimizer_args["lr"] = lr
        self.grad_acc = grad_acc
        self.batch_size = batch_size
        self.seed = seed
        self.keep_data_gpu = keep_data_gpu
        self.verbose = verbose

        # set experimental details
        self.train_details = {
            "type": self.__class__.__name__,
            "train_param": {
                "seed": seed,
                "loss_fn": str(loss_fn),
                "loss_fn_args": loss_fn_args,
                "optimizer": str(optimizer),
                "optimizer_args": optimizer_args,
                "lr": lr,
                "lr_sch": str(lr_sch),
                "lr_sch_args": lr_sch_args,
                "batch_size": batch_size,
                "grad_acc": grad_acc,
            },
            # initial weights, restored at the start of every fit
            "orig_model_param": deepcopy(self.model.state_dict()),
        }

    def fit_epoch(self, train_loader: DataLoader) -> None:
        """Fit one epoch to train model.

        Gradients are accumulated over ``grad_acc`` mini-batches before each
        optimizer step. The last (possibly partial) accumulation group of the
        epoch is always stepped, so no gradients carry over into the next
        epoch or the next run. The learning-rate scheduler, if any, is
        stepped once per epoch after the optimizer updates.

        Parameters
        ----------
        train_loader : DataLoader
            DataLoader used for training.
        """
        # set the model in training mode
        self.model.train()
        # start from clean gradients (stale ones may remain on the model)
        self.optimizer.zero_grad()
        n_batches = len(train_loader)

        # iterate over all the data
        with torch.enable_grad():
            for idx, (data, label) in enumerate(train_loader):
                data, label = data.to(self.device), label.to(self.device)
                out = self.model(data)
                loss = self.loss_fn(out, label)
                loss.backward()
                # gradient accumulation: update every `grad_acc` batches and
                # always on the final batch so leftovers are not leaked
                if (idx + 1) % self.grad_acc == 0 or (idx + 1) == n_batches:
                    # 1 - update parameters
                    self.optimizer.step()
                    # 2 - zero the parameter gradients
                    self.optimizer.zero_grad()

        # update lr
        # Note: Learning rate scheduling should be applied after optimizer's update
        if self.lr_sch is not None:
            self.lr_sch.step()

    def predict(
        self,
        data_loader: DataLoader,
    ) -> tuple[Tensor, Tensor, Tensor]:
        """Predict the class of the input data.

        Parameters
        ----------
        data_loader : DataLoader
            Dataset used for prediction.

        Returns
        -------
        preds : Tensor
            Predicted labels (argmax over dim 1 of the model output), on CPU.
        target : Tensor
            Ground truth (correct) labels, on CPU.
        loss : Tensor
            Average loss, weighted by the size of each mini-batch.
        """
        # set the model in the eval mode
        self.model.eval()
        loss_sum = MeanMetric()
        preds, target = CatMetric(), CatMetric()

        # iterate over all the data
        with torch.no_grad():
            for data, label in data_loader:
                data, label = data.to(self.device), label.to(self.device)
                out = self.model(data)
                loss = self.loss_fn(out, label)
                loss_sum.update(loss.item(), data.size(0))
                # the first element of a tuple output is the prediction
                out = out[0] if isinstance(out, tuple) else out
                # convert the model output to class labels and store them
                preds.update(torch.argmax(out, dim=1).detach().cpu())
                target.update(label.cpu())

        return preds.compute(), target.compute(), loss_sum.compute()

    def data_loader(self, *datasets: EEGData) -> DataLoader:
        """Wrap multiple sets of data and labels and return DataLoader.

        Parameters
        ----------
        datasets : sequence of EEGData
            Sequence of EEGData. Multiple EEGData will be concatenated.

        Returns
        -------
        DataLoader
            Shuffling loader over the concatenated datasets with batch size
            ``self.batch_size``.

        Raises
        ------
        ValueError
            If no dataset is given.
        """
        if len(datasets) == 0:
            raise ValueError("At least one dataset required as input.")

        # dataset wrapping tensors
        tensor_sets = []
        for dataset in datasets:
            data, label = to_tensor(dataset["edata"], dataset["label"])
            if self.keep_data_gpu:
                data, label = data.to(self.device), label.to(self.device)
            tensor_sets.append(TensorDataset(data, label))

        return DataLoader(
            ConcatDataset(tensor_sets), batch_size=self.batch_size, shuffle=True
        )

    def _reset_fitter(self, log_dir: str) -> tuple[str, SummaryWriter, Logger]:
        """Reset the relevant parameters of the fitter.

        Restore the model's initial weights and rebuild the loss function,
        optimizer and learning-rate scheduler from their stored types and
        arguments.

        Parameters
        ----------
        log_dir : str
            Directory location (support hierarchical folder structure) to save
            training log.

        Returns
        -------
        str, SummaryWriter, Logger:
            Return the absolute directory path, a new SummaryWriter object and
            a logger manager for the fitter.

        Raises
        ------
        TypeError
            If the optimizer or lr scheduler type is neither a name nor a
            subclass of the corresponding torch base class.
        """
        # reset parameters of nn.Module
        self.model.load_state_dict(self.train_details["orig_model_param"])

        # create loss function
        if isinstance(self.loss_fn_type, str):
            self.loss_fn = getattr(nn, self.loss_fn_type)(**self.loss_fn_args)
        else:
            self.loss_fn = self.loss_fn_type(**self.loss_fn_args)

        # create optimizer
        if isinstance(self.optimizer_type, str):
            self.optimizer = getattr(optim, self.optimizer_type)(
                self.model.parameters(), **self.optimizer_args
            )
        elif isinstance(self.optimizer_type, type) and issubclass(
            self.optimizer_type, Optimizer
        ):
            # custom optimizers get the same keyword arguments (incl. lr)
            self.optimizer = self.optimizer_type(
                self.model.parameters(), **self.optimizer_args
            )
        else:
            raise TypeError(
                f"Optimizer type ({type(self.optimizer_type)}) is not supported."
            )

        # create lr_scheduler
        if self.lr_sch_type is None:
            self.lr_sch = None
        elif isinstance(self.lr_sch_type, str):
            self.lr_sch = getattr(optim.lr_scheduler, self.lr_sch_type)(
                self.optimizer, **self.lr_sch_args
            )
        elif isinstance(self.lr_sch_type, type) and issubclass(
            self.lr_sch_type, LRScheduler
        ):
            self.lr_sch = self.lr_sch_type(self.optimizer, **self.lr_sch_args)
        else:
            raise TypeError(f"lr_sch type ({type(self.lr_sch_type)}) is not supported.")

        # create log writer
        log_dir = os.path.abspath(log_dir)
        writer = SummaryWriter(log_dir)
        loger = Logger(
            log_dir,
            path=os.path.join(log_dir, "running.log"),
            flevel="INFO",
            clevel=self.verbose,
        )
        return log_dir, writer, loger

    def get_acc(self, preds: Tensor, target: Tensor, ncls: int) -> Tensor:
        """Compute the micro-averaged multiclass accuracy of ``preds``."""
        return multiclass_accuracy(preds, target, ncls, "micro")

    def __repr__(self) -> str:
        s = "[Model architecture]:\n" + self.model_arch + "\n"

        s += f"[Loss function]: {self.loss_fn_type}"
        if self.loss_fn_args:
            s += f"({mapping_to_str(self.loss_fn_args)})\n"
        else:
            s += "\n"

        s += f"[Optimizer]: {self.optimizer_type}"
        if self.optimizer_args:
            s += f"({mapping_to_str(self.optimizer_args)})\n"
        else:
            s += "\n"

        if self.lr_sch_type:
            s += f"[Lr scheduler]: {self.lr_sch_type}"
            if self.lr_sch_args:
                s += f"({mapping_to_str(self.lr_sch_args)})\n"
            else:
                s += "\n"

        s += f"[Grad Acc]: {self.grad_acc}\n"
        s += f"[Batch Size]: {self.batch_size}\n"
        s += f"[Seed]: {self.seed}\n"
        return s
class Classifier(BaseClassifier):
    """Classifier training.

    During training, early stopping is driven by the training set only (a
    validation set is not required), and the same monitored value is used to
    select the model.

    Parameters
    ----------
    model : Module
        Inherit Module and should define the forward method. The first
        parameter returned by model forward propagation is the prediction.
    loss_fn : str, Type[Module]
        Name of the loss function from `torch.nn` which will be used for
        training. If Module, means using a custom loss function.
        Note: custom loss_fn is a class (not an instance), and its
        initialization list is `(**loss_fn_args)`.
    loss_fn_args : dict, optional
        Additional arguments to be passed to the loss function.
    optimizer : str, Type[Optimizer]
        Name of the optimization function from torch.optim which will be used
        for training. If Optimizer, means using a custom optimizer.
        Note: custom optimizer is a class (not an instance), and its
        initialization list is `(model.parameters(), lr=lr, **optimizer_args)`.
    optimizer_args : dict, optional
        Additional arguments to be passed to the optimization function.
    lr : float
        Learning rate.
    lr_sch : str, Type[LRScheduler], optional
        Name of the learning scheduler from `torch.optim.lr_scheduler` which
        will be used for training. If LRScheduler, means using a custom
        learning scheduler.
        Note: custom learning scheduler is a class (not an instance), and its
        initialization list is `(optimizer, **lr_sch_args)`.
    lr_sch_args : dict, optional
        Additional arguments to be passed to the lr_scheduler function.
    grad_acc : int
        Gradient accumulation: number of mini-batches whose gradients are
        accumulated before each optimizer step.
    batch_size : int
        Mini-batch size.
    max_epochs : int
        Maximum number of epochs in training.
    no_increase_epochs : int
        Maximum number of consecutive epochs when the accuracy or loss of the
        training set (the value selected by `var_check`) has no relative
        improvement.
    var_check : str
        The value (train_inacc/train_loss) checked to determine the best
        model, which will be used to evaluate its performance on the test set.
    load_best_state : bool
        If `True`, the best model parameters will be used for evaluation.
        Otherwise the parameters of the last epoch are used.
    nGPU : int
        Select the gpu id to train. If the GPU is not available then the CPU is
        used.
    seed : int
        Random seed, set for reproducibility.
    keep_data_gpu : bool
        Keep the dataset on the GPU to avoid the time consumption of data
        migration. Please adjust according to the personal GPU memory.
    depth : int, optional
        Depth of nested layers to display.
    data_size : tuple, list, optional
        Output the structure of the network model according to the input
        dimension if the `data_size` is given.
    verbose : int, str
        The log level of console. Default is INFO. Mainly used for debug.
    """

    def __init__(
        self,
        model: Module,
        loss_fn: str | Type[Module] = "NLLLoss",
        loss_fn_args: dict | None = None,
        optimizer: str | Type[Optimizer] = "Adam",
        optimizer_args: dict | None = None,
        lr: float = 0.001,
        lr_sch: str | Type[LRScheduler] | None = None,
        lr_sch_args: dict | None = None,
        grad_acc: int = 1,
        batch_size: int = 32,
        max_epochs: int = 1000,
        no_increase_epochs: int = 50,
        var_check: Literal["train_loss", "train_inacc"] = "train_loss",
        load_best_state: bool = True,
        nGPU: int = 0,
        seed: int = DPEEG_SEED,
        keep_data_gpu: bool = True,
        depth: int | None = None,
        data_size: tuple | list | None = None,
        verbose: int | str = "INFO",
    ) -> None:
        super().__init__(
            model,
            loss_fn,
            loss_fn_args,
            optimizer,
            optimizer_args,
            lr,
            lr_sch,
            lr_sch_args,
            grad_acc,
            batch_size,
            nGPU,
            seed,
            keep_data_gpu,
            depth,
            data_size,
            verbose,
        )
        self.max_epochs = max_epochs
        self.no_increase_epochs = no_increase_epochs
        self.var_check = var_check
        self.load_best_state = load_best_state

        # Extend (not replace) the base-class training parameters so that the
        # seed, loss, optimizer, batch size, etc. stay recorded.
        self.train_details["train_param"].update(
            {
                "max_epochs": max_epochs,
                "no_increase_epochs": no_increase_epochs,
                "var_check": var_check,
                "load_best_state": load_best_state,
            }
        )

    def fit(
        self,
        trainset: EEGData,
        testset: EEGData,
        log_dir: str,
    ) -> dict[str, dict[str, Tensor]]:
        """Fit the model.

        Train until `max_epochs` is reached or `var_check` has not decreased
        for `no_increase_epochs` epochs, then evaluate the selected model on
        the training and test sets. The training details and the selected
        parameters are saved to `log_dir`.

        Parameters
        ----------
        trainset : EEGData
            Dataset used for training.
        testset : EEGData
            Dataset used to evaluate the model.
        log_dir : str
            The path to save the training log.

        Returns
        -------
        result : dict
            Returns the training set and test set results (including true
            labels, predicted labels and accuracy).
        """
        log_dir, writer, loger = self._reset_fitter(log_dir)

        # track the best model according to `var_check`
        best_var = float("inf")
        best_model_param = deepcopy(self.model.state_dict())

        # initialize dataloader
        self.trainset = trainset
        self.testset = testset
        train_loader = self.data_loader(trainset)
        test_loader = self.data_loader(testset)
        # NOTE: both splits are scored with the training set's class count
        train_ncls = trainset.ncls

        # start the training
        self.timer.start()
        loger.info(f"[Training...] - [{self.timer.ctime()}]")
        loger.info(f"[Train/Test] - [{trainset.trials()}/{testset.trials()}]")
        stopcri = ComposeStopCriteria(
            {
                "Or": {
                    "cri1": {
                        "MaxEpoch": {"max_epochs": self.max_epochs, "var_name": "epoch"}
                    },
                    "cri2": {
                        "NoDecrease": {
                            "num_epochs": self.no_increase_epochs,
                            "var_name": self.var_check,
                        }
                    },
                }
            }
        )
        monitors = {"epoch": 0, "train_loss": float("inf"), "train_inacc": 1}

        try:
            while not stopcri(monitors):
                # train one epoch
                self.fit_epoch(train_loader)
                monitors["epoch"] += 1

                # evaluate the training accuracy
                train_preds, train_target, train_loss = self.predict(train_loader)
                train_acc = self.get_acc(train_preds, train_target, train_ncls)
                monitors["train_inacc"] = 1 - train_acc
                monitors["train_loss"] = train_loss

                # store loss and acc
                writer.add_scalars(
                    "train", {"loss": train_loss, "acc": train_acc}, monitors["epoch"]
                )
                loger.info(f'-->Epoch : {monitors["epoch"]}')
                loger.info(f" \u21b3train Loss/Acc = {train_loss:.4f}/{train_acc:.4f}")

                # select best model
                if self.load_best_state and monitors[self.var_check] <= best_var:
                    best_var = monitors[self.var_check]
                    best_model_param = deepcopy(self.model.state_dict())
        finally:
            # release the TensorBoard writer even if training fails
            writer.close()

        if not self.load_best_state:
            best_model_param = deepcopy(self.model.state_dict())

        # report the checkpoint time of end and compute cost time
        h, m, s = self.timer.stop()
        loger.info(f"[Train Finish] - [Cost Time = {h}H:{m}M:{s:.2f}S]")

        # load the best model and evaluate this model in testset
        self.model.load_state_dict(best_model_param)

        results = {}
        train_preds, train_target, train_loss = self.predict(train_loader)
        train_acc = self.get_acc(train_preds, train_target, train_ncls)
        results["train"] = {
            "preds": train_preds,
            "target": train_target,
            "acc": train_acc,
        }
        test_preds, test_target, test_loss = self.predict(test_loader)
        test_acc = self.get_acc(test_preds, test_target, train_ncls)
        results["test"] = {"preds": test_preds, "target": test_target, "acc": test_acc}

        loger.info(f"Loss: Train={train_loss:.4f} | Test={test_loss:.4f}")
        loger.info(f"Acc: Train={train_acc:.4f} | Test={test_acc:.4f}")

        self.train_details["results"] = results
        self.train_details["best_model_param"] = best_model_param

        # store the training details
        train_details_path = os.path.join(log_dir, "train_details.pt")
        torch.save(self.train_details, train_details_path)

        # store the best model parameters
        best_checkpoint_path = os.path.join(log_dir, "best_checkpoint.pth")
        torch.save(best_model_param, best_checkpoint_path)

        return results
class ClassifierTwoStage(BaseClassifier):
    """Two-stage classifier training.

    A two-stage training strategy is used. In the first stage, the model is
    trained using only the training set with an early stopping criterion:
    the validation set accuracy (or loss) is monitored, and training stops
    when it has not improved for `no_increase_epochs` consecutive epochs
    (200 by default) or `max_epochs_s1` is reached. After reaching the
    stopping criterion, the network parameters with the best validation set
    accuracy (or loss) are restored (if `load_best_state`).
    In the second stage, the model is trained with the complete training data
    (train + validation set). The second stage training stops when the
    validation set loss falls below the first stage training set loss, or
    when `max_epochs_s2` is reached.

    Parameters
    ----------
    model : Module
        Inherit Module and should define the forward method. The first
        parameter returned by model forward propagation is the prediction.
    loss_fn : str, Type[Module]
        Name of the loss function from `torch.nn` which will be used for
        training. If Module, means using a custom loss function.
        Note: custom loss_fn is a class (not an instance), and its
        initialization list is `(**loss_fn_args)`.
    loss_fn_args : dict, optional
        Additional arguments to be passed to the loss function.
    optimizer : str, Type[Optimizer]
        Name of the optimization function from torch.optim which will be used
        for training. If Optimizer, means using a custom optimizer.
        Note: custom optimizer is a class (not an instance), and its
        initialization list is `(model.parameters(), lr=lr, **optimizer_args)`.
    optimizer_args : dict, optional
        Additional arguments to be passed to the optimization function.
    lr : float
        Learning rate.
    lr_sch : str, Type[LRScheduler], optional
        Name of the learning scheduler from `torch.optim.lr_scheduler` which
        will be used for training. If LRScheduler, means using a custom
        learning scheduler.
        Note: custom learning scheduler is a class (not an instance), and its
        initialization list is `(optimizer, **lr_sch_args)`.
    lr_sch_args : dict, optional
        Additional arguments to be passed to the lr_scheduler function.
    grad_acc : int
        Gradient accumulation: number of mini-batches whose gradients are
        accumulated before each optimizer step.
    batch_size : int
        Mini-batch size.
    max_epochs_s1, max_epochs_s2 : int
        Maximum number of epochs in the first / second stage of training.
    no_increase_epochs : int
        Maximum number of consecutive epochs when the accuracy or loss of the
        first-stage validation set has no relative improvement.
    second_stage : bool
        If `True`, two-stage training will be performed.
    var_check : str
        The value (valid_inacc/valid_loss) checked to determine the best
        state, which will be used for parameter initialization in the
        second stage of model training.
    load_best_state : bool
        If `True`, the second stage will retrain from the best state of the
        first stage. Otherwise, the second stage will retrain from the last
        state of the first stage.
    nGPU : int
        Select the gpu id to train. If the GPU is not available then the CPU is
        used.
    seed : int
        Random seed, set for reproducibility.
    keep_data_gpu : bool
        Keep the dataset on the GPU to avoid the time consumption of data
        migration. Please adjust according to the personal GPU memory.
    depth : int, optional
        Depth of nested layers to display.
    data_size : tuple, list, optional
        Output the structure of the network model according to the input
        dimension if the `data_size` is given.
    verbose : int, str
        The log level of console. Default is INFO. Mainly used for debug.
    """

    def __init__(
        self,
        model: Module,
        loss_fn: str | Type[Module] = "NLLLoss",
        loss_fn_args: dict | None = None,
        optimizer: str | Type[Optimizer] = "Adam",
        optimizer_args: dict | None = None,
        lr: float = 0.001,
        lr_sch: str | Type[LRScheduler] | None = None,
        lr_sch_args: dict | None = None,
        grad_acc: int = 1,
        batch_size: int = 32,
        max_epochs_s1: int = 1500,
        max_epochs_s2: int = 600,
        no_increase_epochs: int = 200,
        second_stage: bool = True,
        var_check: Literal["valid_inacc", "valid_loss"] = "valid_inacc",
        load_best_state: bool = True,
        nGPU: int = 0,
        seed: int = DPEEG_SEED,
        keep_data_gpu: bool = True,
        depth: int | None = None,
        data_size: tuple | list | None = None,
        verbose: int | str = "INFO",
    ) -> None:
        super().__init__(
            model,
            loss_fn,
            loss_fn_args,
            optimizer,
            optimizer_args,
            lr,
            lr_sch,
            lr_sch_args,
            grad_acc,
            batch_size,
            nGPU,
            seed,
            keep_data_gpu,
            depth,
            data_size,
            verbose,
        )
        self.max_epochs_s1 = max_epochs_s1
        self.max_epochs_s2 = max_epochs_s2
        self.no_increase_epochs = no_increase_epochs
        self.second_stage = second_stage
        self.var_check = var_check
        self.load_best_state = load_best_state

        # Extend (not replace) the base-class training parameters so that the
        # seed, loss, optimizer, batch size, etc. stay recorded.
        self.train_details["train_param"].update(
            {
                "max_epochs_s1": max_epochs_s1,
                "max_epochs_s2": max_epochs_s2,
                "no_increase_epochs": no_increase_epochs,
                "second_stage": second_stage,
                "var_check": var_check,
                "load_best_state": load_best_state,
            }
        )

    def fit(
        self,
        trainset: EEGData,
        validset: EEGData,
        testset: EEGData,
        log_dir: str,
    ) -> dict[str, dict[str, Tensor]]:
        """Fit the model.

        Parameters
        ----------
        trainset : EEGData
            Dataset used for training.
        validset : EEGData
            Dataset used for validation.
        testset : EEGData
            Dataset used to evaluate the model.
        log_dir : str
            The path to save the training log.

        Returns
        -------
        dict
            Returns the training set, validation set, and test set results
            (including true labels, predicted labels and accuracy). When the
            second stage ran, the "train" entry is evaluated on the combined
            train + validation loader used in that stage.
        """
        log_dir, writer, loger = self._reset_fitter(log_dir)

        # best state of the first stage (selected with `var_check`)
        best_var = float("inf")
        best_model_param = deepcopy(self.model.state_dict())
        best_optim_param = deepcopy(self.optimizer.state_dict())

        # initialize dataloader
        self.trainset = trainset
        self.validset = validset
        self.testset = testset
        train_loader = self.data_loader(trainset)
        valid_loader = self.data_loader(validset)
        test_loader = self.data_loader(testset)
        train_ncls = trainset.ncls
        valid_ncls = validset.ncls
        test_ncls = testset.ncls

        # start the training
        self.timer.start()
        loger.info(f"[Training...] - [{self.timer.ctime()}]")
        loger.info(
            f"[Train/Valid/Test] - "
            f"[{trainset.trials()}/{validset.trials()}/{testset.trials()}]"
        )
        stopcri = ComposeStopCriteria(
            {
                "Or": {
                    "cri1": {
                        "MaxEpoch": {
                            "max_epochs": self.max_epochs_s1,
                            "var_name": "epoch",
                        }
                    },
                    "cri2": {
                        "NoDecrease": {
                            "num_epochs": self.no_increase_epochs,
                            "var_name": self.var_check,
                        }
                    },
                }
            }
        )
        self.train_details["fit"] = {
            "type": "fit_with_val",
            "var_check": self.var_check,
            "stopcri_1": str(stopcri),
        }
        monitors = {
            "epoch": 0,  # epoch counter of the current stage
            "valid_loss": float("inf"),
            "valid_inacc": 1,
            "global_epoch": 0,  # epoch counter across both stages
            "best_epoch": -1,
            "best_train_loss": float("inf"),
        }
        # local copy: best-state tracking is switched off once stage 2 starts
        load_best_state = self.load_best_state
        early_stop_reached, do_stop = False, False

        try:
            while not do_stop:
                # train one epoch
                self.fit_epoch(train_loader)
                monitors["epoch"] += 1
                monitors["global_epoch"] += 1

                # evaluate the training and validation accuracy
                train_preds, train_target, train_loss = self.predict(train_loader)
                train_acc = self.get_acc(train_preds, train_target, train_ncls)
                valid_preds, valid_target, valid_loss = self.predict(valid_loader)
                valid_acc = self.get_acc(valid_preds, valid_target, valid_ncls)
                monitors["valid_inacc"] = 1 - valid_acc
                monitors["valid_loss"] = valid_loss

                # store loss and acc
                writer.add_scalars(
                    "train",
                    {"loss": train_loss, "acc": train_acc},
                    monitors["global_epoch"],
                )
                writer.add_scalars(
                    "valid",
                    {"loss": valid_loss, "acc": valid_acc},
                    monitors["global_epoch"],
                )

                # print the epoch info
                loger.info(f'-->Epoch : {monitors["epoch"]}')
                loger.info(
                    f" \u21b3train Loss/Acc = {train_loss:.4f}/{train_acc:.4f}"
                    f" | valid Loss/Acc = {valid_loss:.4f}/{valid_acc:.4f}"
                )

                # select best model on Stage 1
                if load_best_state and monitors[self.var_check] <= best_var:
                    best_var = monitors[self.var_check]
                    best_model_param = deepcopy(self.model.state_dict())
                    best_optim_param = deepcopy(self.optimizer.state_dict())
                    monitors["best_train_loss"] = train_loss
                    monitors["best_epoch"] = monitors["epoch"]

                # check whether to stop training
                if not stopcri(monitors):
                    continue

                if self.second_stage and not early_stop_reached:
                    # end of stage 1 -> switch to the second stage
                    early_stop_reached = True
                    epoch = monitors["epoch"]

                    # load the best state
                    if load_best_state:
                        self.model.load_state_dict(best_model_param)
                        self.optimizer.load_state_dict(best_optim_param)
                        train_loss = monitors["best_train_loss"]
                        epoch = monitors["best_epoch"]

                    loger.info("[Early Stopping Reached] -> Training on full set.")
                    loger.info(f"[Epoch = {epoch} | Loss = {train_loss:.4f}]")

                    # combine the train and valid dataset
                    train_loader = self.data_loader(trainset, validset)

                    # stage 2 stops once the valid loss drops below the
                    # stage-1 reference train loss, or at max_epochs_s2
                    stopcri = ComposeStopCriteria(
                        {
                            "Or": {
                                "cri1": {
                                    "MaxEpoch": {
                                        "max_epochs": self.max_epochs_s2,
                                        "var_name": "epoch",
                                    }
                                },
                                "cri2": {
                                    "Smaller": {
                                        "var": train_loss,
                                        "var_name": "valid_loss",
                                    }
                                },
                            }
                        }
                    )
                    self.train_details["fit"]["stopcri_2"] = str(stopcri)
                    monitors["epoch"] = 0
                    load_best_state = False
                elif self.second_stage and early_stop_reached:
                    # end of stage 2: keep the final parameters
                    do_stop = True
                    best_model_param = deepcopy(self.model.state_dict())
                else:
                    # no second stage
                    do_stop = True
        finally:
            # release the TensorBoard writer even if training fails
            writer.close()

        if not load_best_state:
            best_model_param = deepcopy(self.model.state_dict())

        # report the checkpoint time of end and compute cost time
        h, m, s = self.timer.stop()
        loger.info(f"[Train Finish] - [Cost Time = {h}H:{m}M:{s:.2f}S]")

        # load the best model and evaluate this model in testset
        self.model.load_state_dict(best_model_param)

        results = {}
        # NOTE: after the second stage, `train_loader` covers train + valid
        train_preds, train_target, train_loss = self.predict(train_loader)
        train_acc = self.get_acc(train_preds, train_target, train_ncls)
        results["train"] = {
            "preds": train_preds,
            "target": train_target,
            "acc": train_acc,
        }
        valid_preds, valid_target, valid_loss = self.predict(valid_loader)
        valid_acc = self.get_acc(valid_preds, valid_target, valid_ncls)
        results["valid"] = {
            "preds": valid_preds,
            "target": valid_target,
            "acc": valid_acc,
        }
        test_preds, test_target, test_loss = self.predict(test_loader)
        test_acc = self.get_acc(test_preds, test_target, test_ncls)
        results["test"] = {
            "preds": test_preds,
            "target": test_target,
            "acc": test_acc,
        }

        loger.info(
            f"Loss: Train={train_loss:.4f} | Valid={valid_loss:.4f} | "
            f"test={test_loss:.4f}"
        )
        loger.info(
            f"Acc: Train={train_acc:.4f} | Valid={valid_acc:.4f} | "
            f"Test={test_acc:.4f}"
        )

        self.train_details["results"] = results
        self.train_details["best_model_param"] = best_model_param

        # store the training details
        train_details_path = os.path.join(log_dir, "train_details.pt")
        torch.save(self.train_details, train_details_path)

        # store the best model
        best_checkpoint_path = os.path.join(log_dir, "best_checkpoint.pth")
        torch.save(best_model_param, best_checkpoint_path)

        return results