# Authors: SheepTAO <sheeptao@outlook.com>
# License: MIT
# Copyright the dpeeg contributors.
from pathlib import Path
from abc import ABC, abstractmethod
from inspect import getmro
import torch
from torchmetrics.aggregation import CatMetric, MeanMetric
import dpeeg
from ..datasets.base import SplitEEGData, BaseDataset, BaseData
from ..transforms.base import Transforms
from ..trainer.base import Trainer
from ..trainer.classifier import BaseClassifier
from ..tools import Logger, Timer, Filer
from ..utils import DPEEG_DIR, iterable_to_str, _format_log, _format_log_kv
class Experiment(ABC):
    """Experiment base class.

    Parameters
    ----------
    repr : dict
        Basic information on the experiment method. It is rendered by
        ``__repr__`` and written to the experiment summary. A ``"trainer"``
        entry, if present, is dropped because the trainer is logged on its
        own. The caller's dict is not modified.
    trainer : Trainer
        Trainer used for training the model on the dataset.
    out_folder : str, optional
        Root folder for all experimental results. Results are stored under
        ``<out_folder>/<model class name>/<experiment class name>``. If
        ``None``, ``DPEEG_DIR/out`` is used as the root. Default is
        '~/dpeeg/out/model/exp/dataset/timestamp'.
    timestamp : bool
        If ``True``, every call to :meth:`run` writes into a new sub-folder of
        the dataset folder, named with ``Timer.cdate()``.
    verbose : int, str
        The log level of the console. Default is INFO. Mainly used for
        debugging.

    Notes
    -----
    The training results of all models for each subject will be saved under
    the `out_folder` directory.
    """

    def __init__(
        self,
        repr: dict,
        trainer: Trainer,
        out_folder: str | None = None,
        timestamp: bool = True,
        verbose: int | str = "INFO",
    ) -> None:
        # Work on a copy so the caller's dict is left untouched, and tolerate
        # a missing "trainer" entry instead of raising ``KeyError``.
        info = dict(repr)
        info.pop("trainer", None)
        self._repr = info
        self.trainer = trainer
        self.timestamp = timestamp
        self.verbose = verbose

        # create logger and timer
        self.logger = Logger("dpeeg_exp", clevel=verbose)
        self.timer = Timer()

        # set output folder: <root>/<model class name>/<experiment class name>
        net = trainer.model.__class__.__name__
        exp = self.__class__.__name__
        if out_folder:
            root = Path(out_folder).absolute()
        else:
            root = Path(DPEEG_DIR).joinpath("out")
        self.out_folder = root.joinpath(net, exp)

    @abstractmethod
    def _run(self) -> dict:
        """Run the experiment-specific procedure.

        Called by :meth:`run` once ``self.dataset``, ``self.transforms``,
        ``self.data_folder`` and ``self.filer`` have been set. Subclasses
        reconstruct the experimental process according to their experimental
        design.

        Returns
        -------
        dict
            Experimental results. :meth:`run` saves them to ``results.pt``.
        """

    def run(
        self,
        dataset: BaseDataset,
        transforms: Transforms | None = None,
        dataset_name: str | None = None,
        desc: str | None = None,
    ) -> dict:
        """Run the experiment on a dataset and save the results.

        Creates the result folder, writes a summary file, then calls the
        subclass-specific `_run` and saves its return value to
        ``results.pt`` in the result folder.

        Parameters
        ----------
        dataset : BaseDataset
            The dataset used for the experimental test.
        transforms : Transforms, optional
            Pre-transforms for the dataset. They are not applied here. They
            are stored on ``self.transforms`` so that subclasses can apply
            them later, on each subject's data. The rationale is to defer
            data manipulation, especially transformations that can enlarge
            the dataset's memory footprint. The delay lets data be
            manipulated only after subject-independent experiments have
            concatenated the relevant data, or once a subject's data is
            ready (trading time for space). This mitigates the risk of
            memory overflow.
        dataset_name : str, optional
            Name of the folder used to save the results. If ``None`` (or
            empty), the dataset's default name is used.
        desc : str, optional
            A short description of the current experiment. It is written to
            the summary file.

        Returns
        -------
        dict
            The results returned by `_run`.

        Raises
        ------
        FileExistsError
            If the result folder already exists (e.g. ``timestamp=False`` and
            the experiment was already run on this dataset).
        """
        folder_name = dataset_name if dataset_name else dataset._repr["_obj_name"]
        self.data_folder = self.out_folder / folder_name
        if self.timestamp:
            self.data_folder = self.data_folder / Timer.cdate()
        # exist_ok=False: never silently overwrite a previous run's results.
        self.data_folder.mkdir(parents=True, exist_ok=False)
        self.logger.info(f"Results saved in `{self.data_folder}`")

        self.dataset = dataset
        self.transforms = transforms

        # Experiment summary header.
        self.filer = Filer(self.data_folder / "summary.txt")
        self.filer.write(f"[Start Time]: {self.timer.ctime()}\n")
        self.filer.write(f"[DPEEG Version]: {dpeeg.__version__}\n")
        self.filer.write(f"[Description]: {desc}\n")
        self.filer.write(str(dataset) + "\n")
        self.filer.write(_format_log_kv("Transforms", transforms) + "\n")
        self.filer.write(_format_log_kv("Trainer", self.trainer) + "\n")
        self.filer.write(str(self) + "\n")

        self.timer.start()
        self.logger.info("=" * 50)

        # Start the experiment for all subjects.
        results = self._run()

        h, m, s = self.timer.stop()
        torch.save(results, self.data_folder / "results.pt")
        self.logger.info(
            f"\n[All Subjects Finished] - [Cost Time = {h}H:{m}M:{s:.2f}S]"
        )
        self.logger.info("=" * 50)
        return results

    def __repr__(self) -> str:
        return _format_log(self._repr)
class ClsExp(Experiment, ABC):
    """Base class for classification experiments.

    The per-subject routine is selected from the trainer's class hierarchy.
    The most-derived supported class in the trainer's MRO wins:
    ``ClassifierTwoStage`` maps to `_run_sub_classifier_two_stage` and
    ``Classifier`` maps to `_run_sub_classifier`. The selected bound method
    is stored in ``self._run_sub_func``.
    """

    @abstractmethod
    def _run_sub_classifier(self, *args, **kwargs) -> dict:
        """Per-subject routine used when the trainer is a ``Classifier``."""

    @abstractmethod
    def _run_sub_classifier_two_stage(self, *args, **kwargs) -> dict:
        """Per-subject routine used when the trainer is a
        ``ClassifierTwoStage``."""

    def __init__(
        self,
        repr: dict,
        trainer: BaseClassifier,
        out_folder: str | None = None,
        timestamp: bool = True,
        verbose: int | str = "INFO",
    ) -> None:
        """Initialize the experiment and select the per-subject routine.

        Raises
        ------
        TypeError
            If no class in the trainer's MRO is a supported trainer type.
        """
        super().__init__(
            repr=repr,
            trainer=trainer,
            out_folder=out_folder,
            timestamp=timestamp,
            verbose=verbose,
        )

        handlers = {
            "Classifier": self._run_sub_classifier,
            "ClassifierTwoStage": self._run_sub_classifier_two_stage,
        }
        # Walk the MRO from the most-derived class outwards so the most
        # specific supported trainer type is chosen deterministically. A set
        # intersection followed by ``pop()`` would pick an arbitrary element
        # when several supported classes appear in the MRO, because string
        # hashing is randomized per process.
        for base in getmro(type(trainer)):
            handler = handlers.get(base.__name__)
            if handler is not None:
                self._run_sub_func = handler
                break
        else:
            raise TypeError(
                f"Trainer type {type(trainer).__name__} is not supported, "
                f"only {iterable_to_str(handlers.keys())} are supported."
            )

    def _trans_eegdata(self, eegdata: BaseData) -> SplitEEGData:
        """Apply pre-transforms on eegdata.

        If ``self.transforms`` is ``None`` the data is passed through as is.

        Raises
        ------
        TypeError
            If the eegdata is not a `SplitEEGData` after being transformed.
        """
        if self.transforms is not None:
            eegdata = self.transforms(eegdata)
        if not isinstance(eegdata, SplitEEGData):
            raise TypeError("The eegdata is not split.")
        return eegdata

    def _process_sub_dataset(self, subject: int):
        """Preprocess each subject's dataset.

        Subclasses may override this to prepare the data differently
        according to their experimental requirements. By default, the eegdata
        of the given subject in the dataset is returned.
        """
        return self.dataset[subject]

    @abstractmethod
    def _run_sub(self, eegdata: BaseData, sub_folder: Path) -> tuple:
        """Train a model on the specified subject data.

        Called by `_run` once per subject in the dataset. Subclasses
        reconstruct the model training process according to their
        experimental requirements (presumably by delegating to
        ``self._run_sub_func``; verify in the subclasses).

        Parameters
        ----------
        eegdata : BaseData
            The subject's data as returned by `_process_sub_dataset`.
            Pre-transforms have not been applied yet; `_trans_eegdata` can be
            used for that.
        sub_folder : Path
            A freshly created folder (``data_folder / f"sub{subject}"``) in
            which to store all results produced while training on this
            subject.

        Returns
        -------
        tuple
            ``(acc, preds, target, detail)``:

            * ``acc`` is averaged over subjects.
            * ``preds`` and ``target`` are concatenated across subjects.
            * ``detail`` is stored in the results under
              ``"subject_<subject>"``.
        """

    def _run(self) -> dict:
        """Run `_run_sub` for every subject and aggregate the results.

        Returns
        -------
        dict
            The per-subject details under ``"subject_<subject>"`` keys, plus
            ``"acc"`` (mean subject accuracy), ``"preds"`` and ``"target"``
            (concatenated over all subjects).
        """
        result = {}
        acc_metric = MeanMetric()
        preds_metric = CatMetric()
        target_metric = CatMetric()

        for subject in self.dataset.keys():
            self.logger.info(f"\n[Subject-{subject} Training ...]")
            self.logger.info("-" * 50)

            eegdata = self._process_sub_dataset(subject)
            sub_folder = self.data_folder / f"sub{subject}"
            sub_folder.mkdir(parents=True, exist_ok=False)

            acc, preds, target, subject_result = self._run_sub(eegdata, sub_folder)
            result[f"subject_{subject}"] = subject_result
            acc_metric.update(acc)
            preds_metric.update(preds)
            target_metric.update(target)
            self.filer.write(f"Subject_{subject} Acc = {acc*100:.2f}%\n")

        acc = acc_metric.compute()
        self.filer.write(f"Model Acc = {acc*100:.2f}%\n")
        self.logger.info("-" * 50)
        self.logger.info(f"[Model Acc = {acc*100:.2f}%]")

        result.update(
            {
                "acc": acc,
                "preds": preds_metric.compute(),
                "target": target_metric.compute(),
            }
        )
        return result