Source code for dpeeg.exps.holdout

# Authors: SheepTAO <sheeptao@outlook.com>

# License: MIT
# Copyright the dpeeg contributors.

from .base import ClsExp
from ..trainer.classifier import BaseClassifier
from ..transforms.base import SplitTrainTest
from ..utils import get_init_args
from ..tools.docs import fill_doc


@fill_doc
class HoldOut(ClsExp):
    """Holdout cross-validation experiment.

    Estimate how well a model generalizes to unseen data using a single
    train/test split. The model is trained and tested only once, so this
    scheme is fast and well suited to large datasets and quick model
    evaluation.

    Parameters
    ----------
    %(trainer)s
    %(out_folder)s
    %(timestamp)s
    verbose : int | str
        Verbosity setting forwarded unchanged to the base experiment class.
        Defaults to ``"INFO"``.
    """

    def __init__(
        self,
        trainer: BaseClassifier,
        out_folder: str | None = None,
        timestamp: bool = True,
        verbose: int | str = "INFO",
    ) -> None:
        # NOTE: ``locals()`` is captured here, so no extra local variables may
        # be introduced before this call; doing so would change what
        # ``get_init_args`` records as the constructor arguments.
        super().__init__(
            get_init_args(self, locals(), ret_dict=True),
            trainer=trainer,
            out_folder=out_folder,
            timestamp=timestamp,
            verbose=verbose,
        )

    def _run_sub_classifier(self, eegdata, sub_folder):
        """Fit the trainer on the given train split and evaluate on the test split.

        Parameters
        ----------
        eegdata : mapping
            Must provide ``"train"`` and ``"test"`` entries. Presumably these
            are the datasets produced by the experiment's split step; confirm
            against the caller.
        sub_folder : path-like
            Directory passed to the trainer as ``log_dir``.

        Returns
        -------
        Whatever ``self.trainer.fit`` returns, unchanged.
        """
        return self.trainer.fit(
            trainset=eegdata["train"],
            testset=eegdata["test"],
            log_dir=sub_folder,
        )

    def _run_sub_classifier_two_stage(self, eegdata, sub_folder):
        """Carve a validation set out of the training data, then fit.

        ``SplitTrainTest`` is applied to ``eegdata["train"]``. Its ``"train"``
        part is used for training and its ``"test"`` part for validation,
        while ``eegdata["test"]`` is kept as the held-out test set.

        Parameters
        ----------
        eegdata : mapping
            Must provide ``"train"`` and ``"test"`` entries.
        sub_folder : path-like
            Directory passed to the trainer as ``log_dir``.

        Returns
        -------
        Whatever ``self.trainer.fit`` returns, unchanged.
        """
        inner = SplitTrainTest()(eegdata["train"])
        return self.trainer.fit(
            trainset=inner["train"],
            validset=inner["test"],
            testset=eegdata["test"],
            log_dir=sub_folder,
        )

    def _run_sub(self, eegdata, sub_folder):
        """Run the experiment for one subject and unpack the test metrics.

        Logs a header naming ``sub_folder`` and transforms the data with
        ``self._trans_eegdata``. It then delegates to ``self._run_sub_func``,
        which is presumably bound by the base class to one of the runner
        methods above; confirm in ``ClsExp``.

        Returns
        -------
        tuple
            ``(acc, preds, target, result)``. The first three come from
            ``result["test"]`` and ``result`` is the full trainer output.
        """
        banner = f"\n# ---------- {sub_folder.name} ---------- #"
        self.logger.info(banner)
        prepared = self._trans_eegdata(eegdata)
        outcome = self._run_sub_func(prepared, sub_folder)
        test_part = outcome["test"]
        return test_part["acc"], test_part["preds"], test_part["target"], outcome