# Authors: SheepTAO <sheeptao@outlook.com>
# License: MIT
# Copyright the dpeeg contributors.
from sklearn.model_selection import StratifiedKFold
from torchmetrics.aggregation import MeanMetric, CatMetric
from .base import ClsExp
from ..trainer.classifier import BaseClassifier
from ..datasets.base import SplitEEGData
from ..transforms.base import ToEEGData, SplitTrainTest
from ..tools import Filer
from ..utils import DPEEG_SEED, get_init_args
from ..tools.docs import fill_doc
[docs]
@fill_doc
class KFold(ClsExp):
    r"""K-Fold cross validation experiment.

    The KFold experiment divides the dataset into K non-overlapping subsets
    (i.e., "folds") and repeatedly trains and tests the model. The purpose is
    to reduce the dependence of the model evaluation results on the way the
    dataset is divided and to improve the stability and reliability of the
    evaluation results. However, its computational cost is high, especially
    for large datasets and complex models, because one model is trained per
    fold. It may take a long time to complete the training of all folds.

    Two validation methods are provided, selected by the parameter
    ``isolate_testset``. When set to ``True``, the test set is independent of
    the k-fold cross-validation. In each fold, the training data is divided
    into a training set and a validation set to find the optimal parameters
    of that fold. The model is then evaluated on the isolated test set. When
    set to ``False``, each fold of the data is used once as the test set, and
    the remaining folds are used to train the model. The average of all
    folds' evaluations is used as the performance metric of the model. The
    figure below illustrates a 3-fold cross-validation experiment:

    .. image:: ../_static/images/kfold_isolate_testset.png
        :align: center
        :alt: kfold isolate testset

    Isolating the test set is particularly useful when the training set and
    the test set come from different sessions.

    Parameters
    ----------
    %(trainer)s
    %(out_folder)s
    k : int
        Number of folds. Default is 5.
    isolate_testset : bool
        If ``True`` (default), the test set is independent. The k-fold
        cross-validation only splits the training data into training and
        validation sets, to implement an early stopping mechanism. The model
        of each fold is finally evaluated on the isolated test set. If
        ``False``, each fold of the k-fold cross-validation serves in turn as
        the test set.
    shuffle : bool
        Whether to shuffle the samples before splitting them into folds.
    seed : int
        Random seed used for shuffling, for reproducibility. Ignored when
        ``shuffle`` is ``False``.
    %(timestamp)s
    verbose : int | str
        Logging verbosity forwarded to the base experiment.

    Notes
    -----
    If ``isolate_testset`` is ``False``, please provide the ``transforms``
    parameter of the ``run`` function instead of transforming the data in
    advance. This avoids data leakage caused by operations such as data
    augmentation. When ``isolate_testset`` is ``True``, the experiment
    requires the ``trainer`` to support a validation set.
    """

    def __init__(
        self,
        trainer: BaseClassifier,
        out_folder: str | None = None,
        k: int = 5,
        isolate_testset: bool = True,
        shuffle: bool = True,
        seed: int = DPEEG_SEED,
        timestamp: bool = True,
        verbose: int | str = "INFO",
    ) -> None:
        super().__init__(
            get_init_args(self, locals(), ret_dict=True),
            trainer=trainer,
            out_folder=out_folder,
            timestamp=timestamp,
            verbose=verbose,
        )
        # Stored as a string: its length is the zero-padding width of the
        # per-fold labels written to the summary file.
        self.k = str(k)
        self.isolate_testset = isolate_testset
        # scikit-learn raises ValueError when a non-None ``random_state`` is
        # given together with ``shuffle=False``. The seed is therefore only
        # forwarded when it actually has an effect.
        self.skf = StratifiedKFold(
            n_splits=k,
            shuffle=shuffle,
            random_state=seed if shuffle else None,
        )

    def _run_sub_classifier(self, eegdata, sub_folder):
        """K-Fold cross validation for a trainer without a validation set.

        Each fold is used once as the test set and the remaining folds as the
        training set. Transforms are applied per fold, after the split, via
        ``_trans_eegdata``.

        Parameters
        ----------
        eegdata :
            Data of one subject, converted with ``ToEEGData``.
        sub_folder :
            Output folder of the subject (presumably a ``pathlib.Path``). It
            receives ``summary.txt`` and one ``exp{i}`` log directory per
            fold.

        Returns
        -------
        dict
            Per-fold trainer results under ``'exp1'``, ``'exp2'``, ..., plus
            ``'acc'`` (mean test accuracy over folds), and ``'preds'`` and
            ``'target'`` (test predictions/targets concatenated over folds).

        Raises
        ------
        TypeError
            If ``isolate_testset`` is ``True``. This runner never builds a
            validation set, so isolating the test set has no purpose here.
        """
        if self.isolate_testset:
            raise TypeError("Isolated test set is useless with a `Classifier` trainer.")
        eegdata = ToEEGData()(eegdata, verbose=False)

        result = {}
        filer = Filer(sub_folder / "summary.txt")
        train_acc_metric = MeanMetric()
        test_acc_metric = MeanMetric()
        preds_metric = CatMetric()
        target_metric = CatMetric()

        for exp_idx, (train_idx, test_idx) in enumerate(
            self.skf.split(eegdata["edata"], eegdata["label"]), start=1
        ):
            self.logger.info(f"\n# ---- {sub_folder.name}_exp{exp_idx} ---- #")
            trainset = eegdata.index(train_idx)
            testset = eegdata.index(test_idx)
            # Transform after splitting so nothing fitted on the data leaks
            # from the test fold into the training folds.
            egd = self._trans_eegdata(SplitEEGData(trainset, testset))
            train_set = egd["train"]
            test_set = egd["test"]

            exp_folder = sub_folder / f"exp{exp_idx}"
            exp_result = self.trainer.fit(
                trainset=train_set,
                testset=test_set,
                log_dir=exp_folder,
            )
            result[f"exp{exp_idx}"] = exp_result

            train_acc_metric.update(exp_result["train"]["acc"])
            test_acc_metric.update(exp_result["test"]["acc"])
            preds_metric.update(exp_result["test"]["preds"])
            target_metric.update(exp_result["test"]["target"])
            filer.write(
                f"Exp_{str(exp_idx).zfill(len(self.k))} Acc: "
                f"Train={exp_result['train']['acc']:.4f} | "
                f"Test={exp_result['test']['acc']:.4f}\n"
            )

        train_acc = train_acc_metric.compute()
        test_acc = test_acc_metric.compute()
        filer.write(f"Avg Acc = {test_acc*100:.2f}%\n")
        self.logger.info("-" * 30)
        self.logger.info(f"Avg Acc: Train={train_acc:.4f} | Test={test_acc:.4f}")

        result.update(
            {
                "acc": test_acc,
                "preds": preds_metric.compute(),
                "target": target_metric.compute(),
            }
        )
        return result

    def _run_sub_classifier_two_stage(self, eegdata, sub_folder):
        """K-Fold cross validation for a trainer that uses a validation set.

        If ``isolate_testset`` is ``True``, the whole data is first transformed
        with ``_trans_eegdata``, and the folds are drawn from its ``'train'``
        part only. The held-in folds train the model, and the held-out fold
        validates it. Every fold is evaluated on the same ``'test'`` part.

        Otherwise, each fold is used once as the test set. The remaining data
        is transformed together with it via ``_trans_eegdata``. The
        transformed training data is then split into training and validation
        sets with ``SplitTrainTest``.

        Parameters
        ----------
        eegdata :
            Data of one subject. When ``isolate_testset`` is ``True`` it must
            yield ``'train'``/``'test'`` parts after ``_trans_eegdata``.
            Otherwise it is converted with ``ToEEGData``.
        sub_folder :
            Output folder of the subject (presumably a ``pathlib.Path``). It
            receives ``summary.txt`` and one ``exp{i}`` log directory per
            fold.

        Returns
        -------
        dict
            Per-fold trainer results under ``'exp1'``, ``'exp2'``, ..., plus
            ``'acc'`` (mean test accuracy over folds), and ``'preds'`` and
            ``'target'`` (test predictions/targets concatenated over folds).
        """
        if self.isolate_testset:
            eegdata = self._trans_eegdata(eegdata)
            X = eegdata["train"]["edata"]
            y = eegdata["train"]["label"]
        else:
            eegdata = ToEEGData()(eegdata, verbose=False)
            X = eegdata["edata"]
            y = eegdata["label"]

        result = {}
        filer = Filer(sub_folder / "summary.txt")
        train_acc_metric = MeanMetric()
        valid_acc_metric = MeanMetric()
        test_acc_metric = MeanMetric()
        preds_metric = CatMetric()
        target_metric = CatMetric()

        for exp_idx, (train_idx, test_idx) in enumerate(self.skf.split(X, y), start=1):
            self.logger.info(f"\n# ---- {sub_folder.name}_exp{exp_idx} ---- #")
            if self.isolate_testset:
                # The held-out fold validates; the test set stays fixed.
                train_set = eegdata["train"].index(train_idx)
                valid_set = eegdata["train"].index(test_idx)
                test_set = eegdata["test"]
            else:
                trainset = eegdata.index(train_idx)  # type: ignore
                testset = eegdata.index(test_idx)  # type: ignore
                egd = self._trans_eegdata(SplitEEGData(trainset, testset))
                test_set = egd["test"]
                # Carve the validation set out of the (transformed) training
                # folds, never out of the test fold.
                egd = SplitTrainTest()(egd["train"])
                train_set = egd["train"]
                valid_set = egd["test"]

            exp_folder = sub_folder / f"exp{exp_idx}"
            exp_result = self.trainer.fit(
                trainset=train_set,
                validset=valid_set,
                testset=test_set,
                log_dir=exp_folder,
            )
            result[f"exp{exp_idx}"] = exp_result

            train_acc_metric.update(exp_result["train"]["acc"])
            valid_acc_metric.update(exp_result["valid"]["acc"])
            test_acc_metric.update(exp_result["test"]["acc"])
            preds_metric.update(exp_result["test"]["preds"])
            target_metric.update(exp_result["test"]["target"])
            filer.write(
                f"Exp_{str(exp_idx).zfill(len(self.k))} Acc: "
                f"Train={exp_result['train']['acc']:.4f} | "
                f"Valid={exp_result['valid']['acc']:.4f} | "
                f"Test={exp_result['test']['acc']:.4f}\n"
            )

        train_acc = train_acc_metric.compute()
        valid_acc = valid_acc_metric.compute()
        test_acc = test_acc_metric.compute()
        filer.write(f"Avg Acc = {test_acc*100:.2f}%\n")
        self.logger.info("-" * 30)
        self.logger.info(
            f"Avg Acc: Train={train_acc:.4f} | Valid={valid_acc:.4f} | "
            f"Test={test_acc:.4f}"
        )

        result.update(
            {
                "acc": test_acc,
                "preds": preds_metric.compute(),
                "target": target_metric.compute(),
            }
        )
        return result

    def _run_sub(self, eegdata, sub_folder):
        """Run the K-Fold cross validation for one subject.

        The actual per-fold loop is delegated to ``self._run_sub_func``. It is
        presumably bound by the base class to one of the two runner methods of
        this class, depending on the trainer type (verify in ``ClsExp``).

        Parameters
        ----------
        eegdata :
            Data of the subject, forwarded unchanged to the runner.
        sub_folder :
            Output folder of the subject (presumably a ``pathlib.Path``).

        Returns
        -------
        tuple
            ``(acc, preds, target, result)``, where ``result`` is the dict
            returned by the runner. ``result`` holds the per-fold results
            under ``'exp1'``, ``'exp2'``, ..., and ``acc``, ``preds`` and
            ``target`` are its ``'acc'``, ``'preds'`` and ``'target'``
            entries.
        """
        self.timer.start("kfold")
        self.logger.info(f"\n# ---------- {sub_folder.name} ---------- #")
        result = self._run_sub_func(eegdata, sub_folder)
        h, m, s = self.timer.stop("kfold")
        self.logger.info(
            f"\n[{self.k}Fold CV Finish] - [Cost Time = {h}H:{m}M:{s:.2f}S]"
        )
        return result["acc"], result["preds"], result["target"], result