# Source code for cl_gym.algorithms.mtl

import torch
from torch import nn
from torch import optim
from cl_gym.algorithms import ContinualAlgorithm


class Multitask(ContinualAlgorithm):
    """Multitask (joint) training.

    Training data for a task is obtained through the benchmark's
    ``load_joint`` method (see :meth:`prepare_train_loader`).
    """

    def __init__(self, backbone, benchmark, params):
        """Forward all arguments to ``ContinualAlgorithm``.

        The base class is always initialised with ``requires_memory=True``.
        NOTE(review): what the base class does with that flag is not visible
        from this module -- confirm against ``ContinualAlgorithm``.
        """
        super().__init__(backbone, benchmark, params, requires_memory=True)

    def prepare_train_loader(self, task_id):
        """Return the joint training loader for ``task_id``.

        Calls ``self.benchmark.load_joint`` with ``task_id``, the batch size
        stored under ``params['batch_size_train']``, shuffling enabled and
        pinned memory enabled, then returns the first element of the result.
        """
        load_joint = self.benchmark.load_joint
        train_batch_size = self.params['batch_size_train']
        joint_loaders = load_joint(
            task_id,
            train_batch_size,
            shuffle=True,
            pin_memory=True,
        )
        # Presumably load_joint returns a (train, eval) pair of loaders and
        # index 0 is the training one -- TODO confirm against the benchmark.
        return joint_loaders[0]