Source code for cl_gym.algorithms.ogd

import torch
import time
from torch import nn
from torch import optim
from cl_gym.algorithms import ContinualAlgorithm
from cl_gym.algorithms.utils import flatten_grads, assign_grads


[docs]class OGD(ContinualAlgorithm):
    """
    | Orthogonal Gradient Descent
    | By Farajtabar et al. : https://arxiv.org/abs/1910.07104
    """
    # implementation is partially based on:
    # https://github.com/MehdiAbbanaBennani/continual-learning-ogdplus/
    def __init__(self, backbone, benchmark, params):
        super(OGD, self).__init__(backbone, benchmark, params)
        # Flattened gradient vectors collected on memory examples of past tasks
        self.gradient_storage = []
        # Orthonormal basis spanning the stored gradients (rows are basis vectors)
        self.orthonormal_basis = None

    @torch.no_grad()
    def __update_orthonormal_basis(self):
        # QR-decompose the stored gradients (stacked as columns) to obtain an
        # orthonormal basis of their span
        q, r = torch.qr(torch.stack(self.gradient_storage).T)
        self.orthonormal_basis = q.T

    @torch.no_grad()
    def __project_grad_vector(self, g):
        # Project g onto the subspace spanned by the stored past-task gradients
        mid = (torch.mm(self.orthonormal_basis, g.view(-1, 1))).T
        res = torch.mm(mid, self.orthonormal_basis)
        return res.view(-1)

    def __update_gradient_storage(self, task):
        # Compute and store a per-example gradient for each memory example of `task`
        mem_loader_train, _ = self.benchmark.load_memory(task, batch_size=1)
        criterion = self.prepare_criterion(self.current_task)
        optimizer = self.prepare_optimizer(self.current_task)
        self.backbone.train()
        for inp, targ, task_id in mem_loader_train:
            optimizer.zero_grad()
            pred = self.backbone(inp)
            loss = criterion(pred, targ)
            loss.backward()
            grad_batch = flatten_grads(self.backbone).detach().clone()
            self.gradient_storage.append(grad_batch)
            optimizer.zero_grad()
[docs]    def training_task_end(self):
        # After finishing a task, store its memory gradients and refresh the basis
        self.__update_gradient_storage(self.current_task)
        self.__update_orthonormal_basis()
[docs]    def training_step(self, task_id, inp, targ, optimizer, criterion):
        optimizer.zero_grad()
        pred = self.backbone(inp, task_id)
        loss = criterion(pred, targ)
        loss.backward()
        if task_id > 1:
            # Remove the component of the current gradient that lies in the span of
            # past-task gradients, so the update is orthogonal to them
            grad_batch = flatten_grads(self.backbone).detach().clone()
            optimizer.zero_grad()
            proj_grad = self.__project_grad_vector(grad_batch)
            new_grad = grad_batch - proj_grad
            self.backbone = assign_grads(self.backbone, new_grad)
        optimizer.step()
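
The private helpers __update_orthonormal_basis and __project_grad_vector carry the core linear algebra of OGD: build an orthonormal basis for the span of stored past-task gradients via a QR decomposition, then project each new gradient onto that span so its component along past-task directions can be subtracted. The standalone sketch below reproduces that computation in plain PyTorch on toy dimensions (the sizes and tensors are illustrative, not taken from cl_gym) and checks that the residual is orthogonal to every stored gradient; torch.linalg.qr is used here as the non-deprecated equivalent of the torch.qr call in the class.

import torch

# Pretend 5 flattened gradients of dimension 100 were stored from past-task
# memory (both numbers are illustrative only).
stored_grads = [torch.randn(100) for _ in range(5)]

# Orthonormal basis for span{stored gradients}: QR on the matrix whose columns
# are the stored gradients; keep Q^T so each row is a basis vector.
q, _ = torch.linalg.qr(torch.stack(stored_grads).T)
basis = q.T                                      # shape: (num_stored, dim)

# A new current-task gradient.
g = torch.randn(100)

# Projection onto the span of past gradients, mirroring __project_grad_vector:
# coefficients = basis @ g, projection = coefficients @ basis.
coeffs = torch.mm(basis, g.view(-1, 1)).T        # shape: (1, num_stored)
proj = torch.mm(coeffs, basis).view(-1)          # shape: (dim,)

# OGD steps along the residual, which is orthogonal to every stored gradient.
g_orth = g - proj
print(torch.allclose(torch.mv(basis, g_orth), torch.zeros(5), atol=1e-5))  # True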
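
training_step relies on flatten_grads and assign_grads from cl_gym.algorithms.utils to move between per-parameter .grad tensors and a single flat vector. The helpers below are a minimal sketch of what such utilities typically do, not cl_gym's actual implementation, shown only to make the gradient-replacement step in training_step easier to follow.

import torch
from torch import nn

def flatten_grads_sketch(model: nn.Module) -> torch.Tensor:
    # Concatenate every parameter's gradient into one flat vector.
    return torch.cat([p.grad.view(-1) for p in model.parameters()
                      if p.grad is not None])

def assign_grads_sketch(model: nn.Module, flat_grad: torch.Tensor) -> nn.Module:
    # Write a flat gradient vector back into the per-parameter .grad fields,
    # consuming it in the same order flatten_grads_sketch produced it.
    offset = 0
    for p in model.parameters():
        if p.grad is None:
            continue
        numel = p.grad.numel()
        p.grad.copy_(flat_grad[offset:offset + numel].view_as(p.grad))
        offset += numel
    return model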