Source code for cl_gym.utils.callbacks.metric_manager

import os
import numpy as np
from typing import Optional, Callable, Dict
from cl_gym.utils.loggers import Logger
from cl_gym.utils.callbacks import ContinualCallback
from cl_gym.utils.callbacks.helpers import IntervalCalculator, Visualizer
from cl_gym.utils.metrics import ContinualMetric, PerformanceMetric, ForgettingMetric


class MetricCollector(ContinualCallback):
    """
    Collects metrics during learning.
    This callback supports various metrics, such as average accuracy/error
    and average forgetting.
    """
    def __init__(self, num_tasks: int,
                 epochs_per_task: Optional[int] = 1,
                 collect_on_init: bool = False,
                 collect_metrics_for_future_tasks: bool = False,
                 eval_interval: str = 'epoch',
                 eval_type: str = 'classification',
                 tuner_callback: Optional[Callable[[float, bool], None]] = None):
        """
        Args:
            num_tasks: The number of tasks for the learning experience.
            epochs_per_task: The number of epochs per task.
            collect_on_init: Should metrics also be collected before training starts?
            collect_metrics_for_future_tasks: Should metrics be collected for future
                tasks? (e.g., for forward transfer)
            eval_interval: The interval at which the algorithm is evaluated.
                Can be either `task` or `epoch`.
            eval_type: Is this a `classification` task or a `regression` task?
            tuner_callback: Optional tuner callback that can be called with
                evaluation metrics for parameter optimization.
        """
        self.num_tasks = num_tasks
        self.epochs_per_task = epochs_per_task
        self.collect_on_init = collect_on_init
        self.collect_metrics_for_future_tasks = collect_metrics_for_future_tasks
        self.eval_interval = eval_interval.lower()
        self.tuner_callback = tuner_callback
        self.save_dirs = ['plots', 'metrics']
        self.eval_type = eval_type.lower()
        self.meters = self._prepare_meters()
        self.interval_calculator = IntervalCalculator(self.num_tasks,
                                                      self.epochs_per_task,
                                                      self.eval_interval)
        super(MetricCollector, self).__init__("MetricCollector", self.save_dirs)
        self._verify_inputs()

    def _prepare_meters(self) -> Dict[str, ContinualMetric]:
        if self.eval_type == 'classification':
            return {'accuracy': PerformanceMetric(self.num_tasks, self.epochs_per_task),
                    'forgetting': ForgettingMetric(self.num_tasks, self.epochs_per_task),
                    'loss': PerformanceMetric(self.num_tasks, self.epochs_per_task)}
        else:
            return {'loss': PerformanceMetric(self.num_tasks, self.epochs_per_task)}

    def _verify_inputs(self):
        if self.num_tasks <= 0 or self.epochs_per_task <= 0:
            raise ValueError("'num_tasks' and 'epochs_per_task' should be greater than 0")
        if self.eval_interval not in ['epoch', 'task']:
            raise ValueError("'eval_interval' should be either 'task' or 'epoch'")
        if self.eval_type not in ['classification', 'regression']:
            raise ValueError("'eval_type' for metrics should be either 'classification' or 'regression'")

    def _update_meters(self, task_learned: int, task_evaluated: int,
                       metrics: dict, relative_step: int):
        if self.eval_type == 'classification':
            self.meters['loss'].update(task_learned, task_evaluated, metrics['loss'], relative_step)
            self.meters['accuracy'].update(task_learned, task_evaluated, metrics['accuracy'], relative_step)
            # The forgetting meter also records accuracies; forgetting is derived from them.
            self.meters['forgetting'].update(task_learned, task_evaluated, metrics['accuracy'], relative_step)
        else:
            self.meters['loss'].update(task_learned, task_evaluated, metrics['loss'], relative_step)

    def _update_logger(self, trainer, task_evaluated: int, metrics: dict, global_step: int):
        if self.eval_type == 'classification':
            if trainer.logger:
                trainer.logger.log_metric(f'acc_{task_evaluated}', round(metrics['accuracy'], 2), global_step)
                trainer.logger.log_metric(f'loss_{task_evaluated}', round(metrics['loss'], 2), global_step)
            # Chained comparison: current_task > 0 and current_task == task_evaluated.
            if 0 < trainer.current_task == task_evaluated:
                avg_acc = round(self.meters['accuracy'].compute(trainer.current_task), 2)
                print(f"Average accuracy >> {avg_acc}")
                if trainer.logger:
                    trainer.logger.log_metric('average_acc', avg_acc, global_step)
        else:
            if trainer.logger:
                trainer.logger.log_metric(f'loss_{task_evaluated}', round(metrics['loss'], 2), global_step)
            if 0 < trainer.current_task == task_evaluated:
                avg_loss = round(self.meters['loss'].compute(trainer.current_task), 5)
                print(f"Average Loss >> {avg_loss}")
                if trainer.logger:
                    trainer.logger.log_metric('average_loss', avg_loss, global_step)

    def _update_tuner(self, is_final_score: bool):
        if self.tuner_callback is None:
            return
        if self.eval_type == 'classification':
            score = self.meters['accuracy'].compute_final()
        else:
            score = self.meters['loss'].compute_final()
        self.tuner_callback(score, is_final_score)
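    # A minimal sketch of a compatible tuner callback (a hypothetical example;
    # any callable matching Callable[[float, bool], None] works):
    #
    #     def report_to_tuner(score: float, is_final_score: bool) -> None:
    #         # e.g., forward the score to a hyperparameter-tuning framework
    #         print(f"tuner score: {score} (final={is_final_score})")
    #
    #     collector = MetricCollector(num_tasks=5, tuner_callback=report_to_tuner)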
    def log_metrics(self, trainer, task_learned: int, task_evaluated: int,
                    metrics: dict, global_step: int, relative_step: int):
        self._update_meters(task_learned, task_evaluated, metrics, relative_step)
        self._update_logger(trainer, task_evaluated, metrics, global_step)
        self._update_tuner(is_final_score=False)
    def _collect_eval_metrics(self, trainer, start_task: int, end_task: int):
        global_step = trainer.current_task if self.eval_interval == 'task' else trainer.current_epoch
        relative_step = self.interval_calculator.get_step_within_task(trainer.current_epoch)
        for eval_task in range(start_task, end_task + 1):
            task_metrics = trainer.validate_algorithm_on_task(eval_task)
            print(f"[{global_step}] Eval metrics for task {eval_task} >> {task_metrics}")
            self.log_metrics(trainer, trainer.current_task, eval_task,
                             task_metrics, global_step, relative_step)
    def save_metrics(self):
        metrics = ['accuracy', 'loss'] if self.eval_type == 'classification' else ['loss']
        for metric in metrics:
            filepath = os.path.join(self.save_paths['metrics'], metric + ".npy")
            with open(filepath, 'wb') as f:
                np.save(f, self.meters[metric].data)
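    # A minimal sketch of reloading the saved arrays for offline analysis
    # (an assumption: the axis layout of `.data` is defined by ContinualMetric,
    # and 'metrics' is the default save directory):
    #
    #     import numpy as np
    #     acc = np.load("metrics/accuracy.npy")
    #     print(acc.shape)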
    def _prepare_plot_params(self):
        xticks = self.interval_calculator.get_tick_times()
        if self.collect_on_init:
            xticks = [0] + xticks
        return {'show_legend': True,
                'legend_loc': 'lower left',
                'xticks': xticks,
                'xlabel': 'Epochs',
                'ylabel': 'Validation Accuracy'}

    def _extract_task_history(self, task):
        # Returns (x-values, accuracy history) for plotting one task's validation curve.
        if self.collect_on_init:
            start, end = 0, self.interval_calculator.get_task_range(task)[1]
            offset = 0 if self.eval_interval == 'task' else self.epochs_per_task - 1
            metrics = self.meters['accuracy'].get_raw_history(task, 0)[start+offset:end+offset]
        elif self.collect_metrics_for_future_tasks:
            start, end = 1, self.interval_calculator.get_task_range(task)[1]
            metrics = self.meters['accuracy'].get_raw_history(task, 1)
        else:
            start, end = self.interval_calculator.get_task_range(task)
            metrics = self.meters['accuracy'].get_raw_history(task)[start-1:]
            if self.eval_interval == 'task':
                # With task-level evaluation, subsample the epoch-indexed
                # history at task boundaries.
                metrics = self.meters['accuracy'].get_raw_history(task)
                start = self.epochs_per_task - 1
                metrics = metrics[start::self.epochs_per_task]
        return range(start, end), metrics
    def plot_metrics(self, logger: Optional[Logger] = None):
        if self.eval_type != 'classification':
            return
        plot_params = self._prepare_plot_params()
        data, labels = [], []
        for task in range(1, self.num_tasks + 1):
            data.append(self._extract_task_history(task))
            labels.append(f"Task{task}")
        Visualizer.line_plot(data, labels=labels,
                             save_dir=self.save_paths['plots'],
                             filename="metrics",
                             cmap='g10',
                             plot_params=plot_params,
                             logger=logger)
    def on_before_fit(self, trainer):
        if self.collect_on_init:
            print("---------------------------- Init -----------------------")
            self._collect_eval_metrics(trainer, 1, end_task=self.num_tasks)
    def on_after_training_epoch(self, trainer):
        if self.eval_interval != 'epoch':
            return
        if self.collect_on_init or self.collect_metrics_for_future_tasks:
            self._collect_eval_metrics(trainer, 1, self.num_tasks)
        else:
            self._collect_eval_metrics(trainer, 1, trainer.current_task)
    def on_before_training_task(self, trainer):
        print(f"---------------------------- Task {trainer.current_task+1} -----------------------")
    def on_after_training_task(self, trainer):
        if self.eval_interval != 'task':
            return
        if self.collect_on_init or self.collect_metrics_for_future_tasks:
            self._collect_eval_metrics(trainer, 1, self.num_tasks)
        else:
            self._collect_eval_metrics(trainer, 1, trainer.current_task)
    def on_after_fit(self, trainer):
        self._update_tuner(is_final_score=True)
        self.save_metrics()
        self.plot_metrics(trainer.logger)
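

# A minimal usage sketch (an assumption about typical wiring; the trainer
# class below is hypothetical and therefore left commented out):
if __name__ == "__main__":
    collector = MetricCollector(num_tasks=5,
                                epochs_per_task=3,
                                eval_interval='epoch',
                                eval_type='classification')
    # Hypothetical: register the collector with a cl_gym trainer so its
    # on_* hooks fire during training, e.g.:
    # trainer = ContinualTrainer(algorithm, params, callbacks=[collector])
    # trainer.run()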