import torch
# from torch.utils.data import DataLoader, TensorDataset
class ContinualAlgorithm:
    """
    | Base class for continual learning algorithms.
    | It provides the abstractions needed to implement new algorithms, along with the functionality shared by all of them.
    | Used as-is, it covers the Naive (Finetune) algorithm and the Stable SGD algorithm by Mirzadeh et al.
    """

    def __init__(self, backbone, benchmark, params, requires_memory=False):
        """
        :param backbone: Neural network to be trained continually.
        :param benchmark: Benchmark object for preparing task training/memory data.
        :param params: Dictionary of hyperparameters (e.g., 'learning_rate', 'batch_size_train', 'optimizer', 'device', 'criterion').
        :param requires_memory: Whether the algorithm needs an episodic memory/replay buffer (e.g., A-GEM) or not.
        """
        self.backbone = backbone
        self.benchmark = benchmark
        self.params = params
        self.current_task = 1
        self.requires_memory = requires_memory
        self.episodic_memory_loader = None
        self.episodic_memory_iter = None

    def teardown(self):
        pass

    def update_episodic_memory(self):
        # Rebuild the memory loader jointly over all tasks seen so far and reset its iterator.
        self.episodic_memory_loader, _ = self.benchmark.load_memory_joint(self.current_task,
                                                                          batch_size=self.params['batch_size_memory'],
                                                                          shuffle=True,
                                                                          pin_memory=True)
        self.episodic_memory_iter = iter(self.episodic_memory_loader)

    def sample_batch_from_memory(self):
        try:
            batch = next(self.episodic_memory_iter)
        except StopIteration:
            # The memory loader is exhausted; restart it and draw again.
            self.episodic_memory_iter = iter(self.episodic_memory_loader)
            batch = next(self.episodic_memory_iter)
        device = self.params['device']
        inp, targ, task_id = batch
        return inp.to(device), targ.to(device), task_id.to(device)

    def training_task_end(self):
        # Update the episodic memory (if the algorithm uses one) and advance the task counter.
        if self.requires_memory:
            self.update_episodic_memory()
        self.current_task += 1

    def training_epoch_end(self):
        pass

    def training_step_end(self):
        pass

    def prepare_train_loader(self, task_id):
        num_workers = self.params.get('num_dataloader_workers', 0)
        # benchmark.load() returns (train_loader, validation_loader); keep the train loader.
        return self.benchmark.load(task_id, self.params['batch_size_train'],
                                   num_workers=num_workers, pin_memory=True)[0]

    def prepare_validation_loader(self, task_id):
        num_workers = self.params.get('num_dataloader_workers', 0)
        # benchmark.load() returns (train_loader, validation_loader); keep the validation loader.
        return self.benchmark.load(task_id, self.params['batch_size_validation'],
                                   num_workers=num_workers, pin_memory=True)[1]

    def prepare_optimizer(self, task_id):
        # Optionally decay the learning rate exponentially per task, clipped at a lower bound.
        if self.params.get('learning_rate_decay'):
            lr_lower_bound = self.params.get('learning_rate_lower_bound', 0.0)
            lr = max(self.params['learning_rate'] * (self.params['learning_rate_decay'] ** (task_id - 1)), lr_lower_bound)
        else:
            lr = self.params['learning_rate']
        if self.params['optimizer'].lower() == 'sgd':
            return torch.optim.SGD(self.backbone.parameters(), lr=lr, momentum=self.params['momentum'])
        elif self.params['optimizer'].lower() == 'adam':
            return torch.optim.Adam(self.backbone.parameters(), lr=lr)
        else:
            raise ValueError("Only 'SGD' and 'Adam' are accepted. To use another optimizer, override this method.")

    def prepare_criterion(self, task_id):
        return self.params['criterion']

    def training_step(self, task_ids, inp, targ, optimizer, criterion):
        # One optimization step on a single batch: forward, loss, backward, update.
        optimizer.zero_grad()
        pred = self.backbone(inp, task_ids)
        loss = criterion(pred, targ)
        loss.backward()
        optimizer.step()
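

# --- Usage sketch (illustrative, not part of the library) ---
# A minimal example of how this base class might be subclassed and driven.
# 'SimpleReplay', 'train_on_benchmark', 'num_tasks', and 'num_epochs' are
# hypothetical names; the hyperparameter keys and hooks mirror the ones used
# above, and the loaders are assumed to yield (inputs, targets, task_ids)
# batches, as sample_batch_from_memory() does.
class SimpleReplay(ContinualAlgorithm):
    """Naive finetuning plus a plain replay loss on episodic-memory samples."""

    def __init__(self, backbone, benchmark, params):
        super().__init__(backbone, benchmark, params, requires_memory=True)

    def training_step(self, task_ids, inp, targ, optimizer, criterion):
        optimizer.zero_grad()
        # Loss on the current task's batch.
        loss = criterion(self.backbone(inp, task_ids), targ)
        # Add a replay loss once at least one task has been stored in memory.
        if self.current_task > 1:
            mem_inp, mem_targ, mem_task_ids = self.sample_batch_from_memory()
            loss = loss + criterion(self.backbone(mem_inp, mem_task_ids), mem_targ)
        loss.backward()
        optimizer.step()


def train_on_benchmark(algorithm, num_tasks, num_epochs):
    """Drive the algorithm's hooks in the order the base class expects."""
    device = algorithm.params['device']
    for task in range(1, num_tasks + 1):
        train_loader = algorithm.prepare_train_loader(task)
        optimizer = algorithm.prepare_optimizer(task)
        criterion = algorithm.prepare_criterion(task)
        for _ in range(num_epochs):
            for inp, targ, task_ids in train_loader:
                # Assumes the loader yields tensors that can be moved to the target device.
                inp, targ, task_ids = inp.to(device), targ.to(device), task_ids.to(device)
                algorithm.training_step(task_ids, inp, targ, optimizer, criterion)
                algorithm.training_step_end()
            algorithm.training_epoch_end()
        algorithm.training_task_end()
    algorithm.teardown()

# Example hyperparameters (keys taken from the methods above; values are only placeholders):
# {'device': 'cuda', 'batch_size_train': 64, 'batch_size_memory': 64,
#  'batch_size_validation': 256, 'learning_rate': 0.01, 'momentum': 0.9,
#  'optimizer': 'sgd', 'criterion': torch.nn.CrossEntropyLoss()}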