# Source code for cl_gym.benchmarks.cifar

import torchvision
from typing import Optional, Tuple
from cl_gym.benchmarks.utils import DEFAULT_DATASET_DIR
from cl_gym.benchmarks import Benchmark, SplitDataset
from cl_gym.benchmarks.transforms import get_default_cifar_transform


class SplitCIFAR(Benchmark):
    """
    Shared base class of the Split-CIFAR-10 and Split-CIFAR-100 benchmarks.
    """
    def __init__(self,
                 num_tasks: int,
                 per_task_examples: Optional[int] = None,
                 per_task_joint_examples: Optional[int] = 0,
                 per_task_memory_examples: Optional[int] = 0,
                 per_task_subset_examples: Optional[int] = 0,
                 task_input_transforms: Optional[list] = None,
                 task_target_transforms: Optional[list] = None,
                 is_cifar_100: bool = True):
        """
        Build the benchmark, then immediately load and prepare its datasets.

        Args:
            num_tasks: Number of tasks. 20 for CIFAR-100 and 5 for CIFAR-10.
            per_task_examples: If set, each task holds only part of the original data instead of all of it.
            per_task_joint_examples: If set, the benchmark supports joint/multitask loading of tasks.
            per_task_memory_examples: If set, the benchmark supports episodic memory/replay buffer loading.
            per_task_subset_examples: If set, the benchmark supports loading a pre-defined subset of each task.
            task_input_transforms: If set, the provided torchvision transforms are used for the inputs.
                When ``None``, ``get_default_cifar_transform(num_tasks, is_cifar_100)`` supplies them.
            task_target_transforms: If set, the provided target transforms are used for the targets.
            is_cifar_100: If true, prepares CIFAR-100 (5 classes per task), otherwise CIFAR-10 (2 classes per task).

        .. note::
            If :attr:`task_input_transforms` or :attr:`task_target_transforms` are given, they should be
            lists of size `num_tasks` whose elements can each be a torchvision (Composed) transform.

        .. note::
            Similar to Torchvision, CIFAR benchmarks in CL-Gym are first loaded in the memory for faster loading.
        """
        self.is_cifar_100 = is_cifar_100

        # Fall back to the library defaults when the caller supplied no input transforms.
        if task_input_transforms is None:
            task_input_transforms = get_default_cifar_transform(num_tasks, is_cifar_100)

        # Book-keeping: CIFAR-100 uses 5 classes per task (20 tasks),
        # CIFAR-10 uses 2 classes per task (5 tasks).
        if self.is_cifar_100:
            self.num_classes_per_split = 5
        else:
            self.num_classes_per_split = 2

        super().__init__(num_tasks, per_task_examples, per_task_joint_examples, per_task_memory_examples,
                         per_task_subset_examples, task_input_transforms, task_target_transforms)
        self.load_datasets()
        self.prepare_datasets()

    def __load_cifar(self):
        """
        Instantiate the torchvision train and test sets for the selected CIFAR variant.

        Both sets are created under ``DEFAULT_DATASET_DIR`` with ``download=True`` and are
        stored on ``self.cifar_train`` / ``self.cifar_test``.
        """
        if self.is_cifar_100:
            dataset_cls = torchvision.datasets.CIFAR100
        else:
            dataset_cls = torchvision.datasets.CIFAR10

        # Only the first entry of the per-task transform list is handed to torchvision here.
        shared_transform = self.task_input_transforms[0]

        self.cifar_train = dataset_cls(DEFAULT_DATASET_DIR, train=True, download=True, transform=shared_transform)
        self.cifar_test = dataset_cls(DEFAULT_DATASET_DIR, train=False, download=True, transform=shared_transform)

    def load_datasets(self):
        """
        Loads the CIFAR dataset [in memory] and builds one train/test split per task.

        Task ids are 1-based: ``self.trains`` and ``self.tests`` receive keys ``1..num_tasks``.
        """
        self.__load_cifar()
        classes_per_task = self.num_classes_per_split
        for task_id in range(1, self.num_tasks + 1):
            self.trains[task_id] = SplitDataset(task_id, classes_per_task, self.cifar_train)
            self.tests[task_id] = SplitDataset(task_id, classes_per_task, self.cifar_test)

    def precompute_memory_indices(self):
        """
        Precomputes the memory (replay) indices of every task.

        For task ``t`` (1-based) with ``k = num_classes_per_split``, indices are sampled for the
        class range ``(t - 1) * k`` to ``t * k - 1`` from both the train and the test split,
        and stored in ``self.memory_indices_train[t]`` / ``self.memory_indices_test[t]``.

        .. note::
            The default behavior is class-uniform sampling.
            i.e., each class will have roughly equal number of samples in the memory.
            You can inherit this class and override this method for custom behavior. But a better way
            is to move this logic to your algorithm component's code.
        """
        classes_per_task = self.num_classes_per_split
        for task_id in range(1, self.num_tasks + 1):
            first_cls = (task_id - 1) * classes_per_task
            # Upper bound passed to the sampler is the last class of the task, not one past it.
            last_cls = first_cls + classes_per_task - 1
            budget = self.per_task_memory_examples

            # Train is sampled before test, in that order.
            train_picks = self.sample_uniform_class_indices(self.trains[task_id], first_cls, last_cls, budget)
            test_picks = self.sample_uniform_class_indices(self.tests[task_id], first_cls, last_cls, budget)

            # Sanity check: only the train selection is checked against the requested size.
            assert len(train_picks) == self.per_task_memory_examples

            # Store copies rather than the objects returned by the sampler.
            self.memory_indices_train[task_id] = train_picks[:]
            self.memory_indices_test[task_id] = test_picks[:]


class SplitCIFAR100(SplitCIFAR):
    """
    Split CIFAR-100 benchmark.
    Has 20 tasks, each task with 5 classes of CIFAR-100.
    """
    def __init__(self,
                 num_tasks: int,
                 per_task_examples: Optional[int] = None,
                 per_task_joint_examples: Optional[int] = 0,
                 per_task_memory_examples: Optional[int] = 0,
                 per_task_subset_examples: Optional[int] = 0,
                 task_input_transforms: Optional[list] = None,
                 task_target_transforms: Optional[list] = None):
        """
        Forward every argument unchanged to :class:`SplitCIFAR` with ``is_cifar_100=True``.

        Args:
            num_tasks: Number of tasks.
            per_task_examples: Passed through to :class:`SplitCIFAR`.
            per_task_joint_examples: Passed through to :class:`SplitCIFAR`.
            per_task_memory_examples: Passed through to :class:`SplitCIFAR`.
            per_task_subset_examples: Passed through to :class:`SplitCIFAR`.
            task_input_transforms: Passed through to :class:`SplitCIFAR`.
            task_target_transforms: Passed through to :class:`SplitCIFAR`.
        """
        super().__init__(num_tasks, per_task_examples, per_task_joint_examples, per_task_memory_examples,
                         per_task_subset_examples, task_input_transforms, task_target_transforms,
                         is_cifar_100=True)


class SplitCIFAR10(SplitCIFAR):
    """
    Split CIFAR-10 benchmark.
    Has 5 tasks, each with 2 classes of CIFAR-10.
    """
    def __init__(self,
                 num_tasks: int,
                 per_task_examples: Optional[int] = None,
                 per_task_joint_examples: Optional[int] = 0,
                 per_task_memory_examples: Optional[int] = 0,
                 per_task_subset_examples: Optional[int] = 0,
                 task_input_transforms: Optional[list] = None,
                 task_target_transforms: Optional[list] = None):
        """
        Forward every argument unchanged to :class:`SplitCIFAR` with ``is_cifar_100=False``.

        Args:
            num_tasks: Number of tasks.
            per_task_examples: Passed through to :class:`SplitCIFAR`.
            per_task_joint_examples: Passed through to :class:`SplitCIFAR`.
            per_task_memory_examples: Passed through to :class:`SplitCIFAR`.
            per_task_subset_examples: Passed through to :class:`SplitCIFAR`.
            task_input_transforms: Passed through to :class:`SplitCIFAR`.
            task_target_transforms: Passed through to :class:`SplitCIFAR`.
        """
        super().__init__(num_tasks, per_task_examples, per_task_joint_examples, per_task_memory_examples,
                         per_task_subset_examples, task_input_transforms, task_target_transforms,
                         is_cifar_100=False)