"""Source code for cl_gym.benchmarks.mnist: the Rotated-, Permuted- and Split-MNIST benchmarks."""

import torchvision
from typing import Optional, Tuple
from cl_gym.benchmarks.utils import DEFAULT_DATASET_DIR
from cl_gym.benchmarks.transforms import get_default_mnist_transform
from cl_gym.benchmarks.transforms import get_default_rotation_mnist_transform
from cl_gym.benchmarks.transforms import get_default_permuted_mnist_transform
from cl_gym.benchmarks.base import Benchmark, DynamicTransformDataset, SplitDataset


class ContinualMNIST(Benchmark):
    """
    Base class for (Permuted/Rotated/Split)-MNIST benchmarks.

    Every task wraps the same underlying MNIST train/test sets; tasks differ only
    in the per-task input/target transforms applied on top of them.
    """

    def __init__(self,
                 num_tasks: int,
                 per_task_examples: Optional[int] = None,
                 per_task_joint_examples: Optional[int] = 0,
                 per_task_memory_examples: Optional[int] = 0,
                 per_task_subset_examples: Optional[int] = 0,
                 task_input_transforms: Optional[list] = None,
                 task_target_transforms: Optional[list] = None):
        """
        Args:
            num_tasks: The number of tasks for the benchmark.
            per_task_examples: If set, each task will include part of the original benchmark rather than full data.
            per_task_joint_examples: If set, the benchmark will support joint/multitask loading of tasks.
            per_task_memory_examples: If set, the benchmark will support episodic memory/replay buffer loading of tasks.
            per_task_subset_examples: If set, the benchmark will support loading a pre-defined subset of each task.
            task_input_transforms: If set, the benchmark will use the provided torchvision transform.
                If ``None``, the default MNIST transform is used for every task.
            task_target_transforms: If set, the benchmark will use the provided target transform for targets.

        .. note::
            If :attr:`task_input_transforms` or :attr:`task_target_transforms` are set, they should be
            a list of size `num_tasks` where each element of the list can be a torchvision (Composed) transform.
        """
        if task_input_transforms is None:
            # load_datasets() indexes this list once per task, so leaving it as None
            # would fail with a TypeError. Fall back to the plain MNIST transform, the
            # same way the subclasses install their own defaults.
            task_input_transforms = get_default_mnist_transform(num_tasks)
        super().__init__(num_tasks, per_task_examples, per_task_joint_examples, per_task_memory_examples,
                         per_task_subset_examples, task_input_transforms, task_target_transforms)
        self.load_datasets()
        # prepare_datasets() is not defined in this class; it is inherited.
        self.prepare_datasets()

    def __load_mnist(self):
        """Load the raw MNIST train/test sets (downloading them if needed), without any transform."""
        self.mnist_train = torchvision.datasets.MNIST(DEFAULT_DATASET_DIR, train=True, download=True)
        self.mnist_test = torchvision.datasets.MNIST(DEFAULT_DATASET_DIR, train=False, download=True)

    def load_datasets(self):
        """
        Build the per-task train and test datasets.

        Tasks are numbered from 1, while the transform lists are 0-indexed, hence the
        ``task - 1`` lookups. Each task wraps the shared MNIST data in a
        :class:`DynamicTransformDataset` carrying that task's transforms.
        """
        self.__load_mnist()
        for task in range(1, self.num_tasks + 1):
            input_transform = self.task_input_transforms[task - 1]
            # Target transforms are optional; no transform is passed when none were given.
            target_transform = self.task_target_transforms[task - 1] if self.task_target_transforms else None
            self.trains[task] = DynamicTransformDataset(task, self.mnist_train, input_transform, target_transform)
            self.tests[task] = DynamicTransformDataset(task, self.mnist_test, input_transform, target_transform)

    def precompute_memory_indices(self):
        """
        Pre-select, for every task, the example indices used for the episodic memory.

        Indices are sampled over digit classes 0 through 9 from the underlying (untransformed)
        dataset of each task, with `per_task_memory_examples` indices per task and split.
        """
        for task in range(1, self.num_tasks + 1):
            indices_train = self.sample_uniform_class_indices(self.trains[task].dataset, 0, 9,
                                                              self.per_task_memory_examples)
            indices_test = self.sample_uniform_class_indices(self.tests[task].dataset, 0, 9,
                                                             self.per_task_memory_examples)
            # Sanity check: the sampler must return exactly the requested number of indices.
            assert len(indices_train) == len(indices_test) == self.per_task_memory_examples
            # Store copies so later changes to the sampled lists cannot affect the stored indices.
            self.memory_indices_train[task] = indices_train[:]
            self.memory_indices_test[task] = indices_test[:]
class RotatedMNIST(ContinualMNIST):
    """
    Rotated MNIST benchmark.
    """

    def __init__(self,
                 num_tasks: int,
                 per_task_examples: Optional[int] = None,
                 per_task_joint_examples: Optional[int] = 0,
                 per_task_memory_examples: Optional[int] = 0,
                 per_task_subset_examples: Optional[int] = 0,
                 task_input_transforms: Optional[list] = None,
                 task_target_transforms: Optional[list] = None,
                 per_task_rotation: Optional[float] = None):
        """
        Args:
            num_tasks: The number of tasks for the benchmark.
            per_task_examples: If set, each task will include part of the original benchmark rather than full data.
            per_task_joint_examples: If set, the benchmark will support joint/multitask loading of tasks.
            per_task_memory_examples: If set, the benchmark will support episodic memory/replay buffer loading of tasks.
            per_task_subset_examples: If set, the benchmark will support loading a pre-defined subset of each task.
            task_input_transforms: If set, the benchmark will use the provided torchvision transform.
            task_target_transforms: If set, the benchmark will use the provided target transform for targets.
            per_task_rotation: Rotation degrees per task. Only consulted when
                `task_input_transforms` is not provided.

        .. note::
            The first task won't have any rotations. e.g., if rotation per task is 45 degrees,
            first task has 0 deg rotation, second task has 45 degree, etc.
        """
        transforms = task_input_transforms
        if transforms is None:
            # No explicit transforms given: build the default rotation schedule.
            transforms = get_default_rotation_mnist_transform(num_tasks, per_task_rotation)
        super().__init__(
            num_tasks=num_tasks,
            per_task_examples=per_task_examples,
            per_task_joint_examples=per_task_joint_examples,
            per_task_memory_examples=per_task_memory_examples,
            per_task_subset_examples=per_task_subset_examples,
            task_input_transforms=transforms,
            task_target_transforms=task_target_transforms,
        )
class PermutedMNIST(ContinualMNIST):
    """
    Permuted MNIST benchmark
    """

    def __init__(self,
                 num_tasks: int,
                 per_task_examples: Optional[int] = None,
                 per_task_joint_examples: Optional[int] = 0,
                 per_task_memory_examples: Optional[int] = 0,
                 per_task_subset_examples: Optional[int] = 0,
                 task_input_transforms: Optional[list] = None,
                 task_target_transforms: Optional[list] = None):
        """
        Args:
            num_tasks: The number of tasks for the benchmark.
            per_task_examples: If set, each task will include part of the original benchmark rather than full data.
            per_task_joint_examples: If set, the benchmark will support joint/multitask loading of tasks.
            per_task_memory_examples: If set, the benchmark will support episodic memory/replay buffer loading of tasks.
            per_task_subset_examples: If set, the benchmark will support loading a pre-defined subset of each task.
            task_input_transforms: If set, the benchmark will use the provided torchvision transform.
            task_target_transforms: If set, the benchmark will use the provided target transform for targets.

        .. note::
            The first task will be the default MNIST and the permutation applies
            to tasks 2, 3, ..., `num_tasks`.
        """
        transforms = task_input_transforms
        if transforms is None:
            # No explicit transforms given: use the default per-task permutations.
            transforms = get_default_permuted_mnist_transform(num_tasks)
        super().__init__(
            num_tasks=num_tasks,
            per_task_examples=per_task_examples,
            per_task_joint_examples=per_task_joint_examples,
            per_task_memory_examples=per_task_memory_examples,
            per_task_subset_examples=per_task_subset_examples,
            task_input_transforms=transforms,
            task_target_transforms=task_target_transforms,
        )
class SplitMNIST(ContinualMNIST):
    """
    Split MNIST benchmark.
    The benchmark can have at most 5 tasks, each a binary classification on MNIST digits.
    """

    def __init__(self,
                 num_tasks: int,
                 per_task_examples: Optional[int] = None,
                 per_task_joint_examples: Optional[int] = 0,
                 per_task_memory_examples: Optional[int] = 0,
                 per_task_subset_examples: Optional[int] = 0,
                 task_input_transforms: Optional[list] = None,
                 task_target_transforms: Optional[list] = None):
        """
        Args:
            num_tasks: The number of tasks for the benchmark (between 1 and 5).
            per_task_examples: If set, each task will include part of the original benchmark rather than full data.
            per_task_joint_examples: If set, the benchmark will support joint/multitask loading of tasks.
            per_task_memory_examples: If set, the benchmark will support episodic memory/replay buffer loading of tasks.
            per_task_subset_examples: If set, the benchmark will support loading a pre-defined subset of each task.
            task_input_transforms: If set, the benchmark will use the provided torchvision transform.
                If ``None``, the default MNIST transform is used.
            task_target_transforms: Forwarded to the base initializer; this class's
                :meth:`load_datasets` does not apply target transforms itself.

        Raises:
            ValueError: If `num_tasks` is smaller than 1 or larger than 5.

        .. note::
            Only the first element of :attr:`task_input_transforms` is used: it is attached to the
            shared MNIST datasets, so every task sees the same input transform.
        """
        if num_tasks < 1:
            # Without this check a non-positive task count fails later with an opaque
            # IndexError when the first input transform is looked up.
            raise ValueError("Split MNIST benchmark needs at least 1 task")
        if num_tasks > 5:
            raise ValueError("Split MNIST benchmark can have at most 5 tasks (i.e., 10 classes, 2 per task)")
        if task_input_transforms is None:
            task_input_transforms = get_default_mnist_transform(num_tasks)
        super().__init__(num_tasks, per_task_examples, per_task_joint_examples, per_task_memory_examples,
                         per_task_subset_examples, task_input_transforms, task_target_transforms)

    def __load_mnist(self):
        """Load the MNIST train/test sets (downloading them if needed) with the first input transform attached."""
        # A single transform is shared by all tasks; tasks differ by class subset, not by transform.
        transforms = self.task_input_transforms[0]
        self.mnist_train = torchvision.datasets.MNIST(DEFAULT_DATASET_DIR, train=True, download=True,
                                                      transform=transforms)
        self.mnist_test = torchvision.datasets.MNIST(DEFAULT_DATASET_DIR, train=False, download=True,
                                                     transform=transforms)

    def load_datasets(self):
        """Build one :class:`SplitDataset` per task (numbered from 1) over the shared MNIST train and test sets."""
        self.__load_mnist()
        for task in range(1, self.num_tasks + 1):
            # The literal 2 is presumably the number of classes per task -- confirm against SplitDataset.
            self.trains[task] = SplitDataset(task, 2, self.mnist_train)
            self.tests[task] = SplitDataset(task, 2, self.mnist_test)

    def precompute_memory_indices(self):
        """
        Pre-select, for every task, the example indices used for the episodic memory.

        Task ``t`` covers the two consecutive digit classes ``2 * (t - 1)`` and ``2 * t - 1``
        (task 1 -> 0 and 1, task 2 -> 2 and 3, ...), and indices are sampled within that range.
        """
        for task in range(1, self.num_tasks + 1):
            start_cls = (task - 1) * 2
            end_cls = task * 2 - 1
            num_examples = self.per_task_memory_examples
            indices_train = self.sample_uniform_class_indices(self.trains[task], start_cls, end_cls, num_examples)
            indices_test = self.sample_uniform_class_indices(self.tests[task], start_cls, end_cls, num_examples)
            # Sanity check: the sampler must return exactly the requested number of indices.
            assert len(indices_train) == len(indices_test) == self.per_task_memory_examples
            # Store copies so later changes to the sampled lists cannot affect the stored indices.
            self.memory_indices_train[task] = indices_train[:]
            self.memory_indices_test[task] = indices_test[:]