Source code for cl_gym.trainer.base

import torch
from typing import Dict, Iterable, Optional
from cl_gym.algorithms import ContinualAlgorithm
from cl_gym.utils.callbacks import ContinualCallback
from cl_gym.utils.loggers import Logger
from cl_gym.trainer.state_manager import TrainerStateManagerMixin
from cl_gym.trainer.callback_hooks import TrainerCallbackHookMixin


class ContinualTrainer(TrainerStateManagerMixin, TrainerCallbackHookMixin):
    """
    Base class for the Trainer component.
    Orchestrates training by implementing a state machine: `run()` walks through
    setup, fit, and teardown, firing the corresponding before/after callback hooks.
    For further details, please see the guides on how the trainer works.
    """
    def __init__(self,
                 algorithm: ContinualAlgorithm,
                 params: dict,
                 callbacks: Iterable[ContinualCallback] = (),
                 logger: Optional[Logger] = None):
        self.params = params
        self.algorithm = algorithm
        self.callbacks = callbacks
        self.logger = logger
    def setup(self):
        # Attach each callback to this trainer so hooks can access its state.
        for cb in self.callbacks:
            cb.connect(self)
    def teardown(self):
        pass
    def train_algorithm_on_task(self, task: int):
        train_loader = self.algorithm.prepare_train_loader(task)
        optimizer = self.algorithm.prepare_optimizer(task)
        criterion = self.algorithm.prepare_criterion(task)
        device = self.params['device']
        for epoch in range(1, self.params['epochs_per_task'] + 1):
            self.on_before_training_epoch()
            self.tick('epoch')
            self.algorithm.backbone.train()
            self.algorithm.backbone = self.algorithm.backbone.to(device)
            for batch_idx, (inp, targ, task_ids) in enumerate(train_loader):
                # One optimization step per batch, wrapped in the step hooks.
                self.on_before_training_step()
                self.tick('step')
                self.algorithm.training_step(task_ids.to(device), inp.to(device),
                                             targ.to(device), optimizer, criterion)
                self.algorithm.training_step_end()
                self.on_after_training_step()
            self.algorithm.training_epoch_end()
            self.on_after_training_epoch()
        self.algorithm.training_task_end()
    def validate_algorithm_on_task(self, task: int,
                                   validate_on_train: bool = False) -> Dict[str, float]:
        self.algorithm.backbone.eval()
        device = self.params['device']
        self.algorithm.backbone = self.algorithm.backbone.to(device)
        test_loss = 0
        correct = 0
        total = 0
        # Optionally evaluate on the training set instead of the validation set.
        if validate_on_train:
            eval_loader = self.algorithm.prepare_train_loader(task)
        else:
            eval_loader = self.algorithm.prepare_validation_loader(task)
        criterion = self.algorithm.prepare_criterion(task)
        with torch.no_grad():
            for (inp, targ, task_ids) in eval_loader:
                inp, targ, task_ids = inp.to(device), targ.to(device), task_ids.to(device)
                pred = self.algorithm.backbone(inp, task_ids)
                total += len(targ)
                test_loss += criterion(pred, targ).item()
                # Count top-1 correct predictions.
                pred = pred.data.max(1, keepdim=True)[1]
                correct += pred.eq(targ.data.view_as(pred)).sum()
        test_loss /= total
        correct = correct.cpu()
        avg_acc = 100.0 * float(correct.numpy()) / total
        return {'accuracy': avg_acc, 'loss': test_loss}
    def fit(self):
        for task in range(1, self.params['num_tasks'] + 1):
            self.on_before_training_task()
            self.tick('task')
            self.train_algorithm_on_task(task)
            self.on_after_training_task()
    def _run_setup(self):
        self.on_before_setup()
        self.setup()
        self.on_after_setup()

    def _run_fit(self):
        self.on_before_fit()
        self.fit()
        self.on_after_fit()

    def _run_teardown(self):
        self.on_before_teardown()
        self.teardown()
        self.on_after_teardown()
    def run(self):
        # setup
        self._run_setup()
        # fit: main training loop
        self._run_fit()
        # teardown
        self._run_teardown()
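
A minimal driver sketch follows, using only what the listing above defines. The names my_algorithm, my_callbacks, and my_logger, and the hyperparameter values in params, are illustrative assumptions rather than part of this module; consult the cl_gym guides for how to construct an actual ContinualAlgorithm.

import torch
from cl_gym.trainer.base import ContinualTrainer

# Assumed to be built elsewhere: a ContinualAlgorithm instance, an iterable of
# ContinualCallback objects, and an optional Logger (all hypothetical names).
# my_algorithm, my_callbacks, my_logger = ...

# The trainer reads exactly these keys from `params` (values here are arbitrary).
params = {
    'num_tasks': 5,                                             # used by fit()
    'epochs_per_task': 1,                                       # used by train_algorithm_on_task()
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',   # used by train/validate loops
}

trainer = ContinualTrainer(algorithm=my_algorithm,
                           params=params,
                           callbacks=my_callbacks,
                           logger=my_logger)
trainer.run()   # setup -> fit -> teardown, with before/after hooks at each stage

# Per-task evaluation after training:
metrics = trainer.validate_algorithm_on_task(task=1)
print(metrics['accuracy'], metrics['loss'])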