# Source code for neurobench.metrics.workload.r2

import torch
from neurobench.metrics.abstract.workload_metric import AccumulatedMetric
from neurobench.metrics.utils.decorators import check_shapes


class R2(AccumulatedMetric):
    """
    R2 Score of the model predictions.

    Currently implemented for 2D output only: column 0 of the predictions and
    labels is scored as the "x" target and column 1 as the "y" target. The
    reported score is the mean of the two per-column R2 scores, computed over
    every sample accumulated since the last reset.
    """

    def __init__(self):
        """
        Initialize metric state.

        Must hold memory of all labels seen so far, because the R2 denominator
        (sum of squared deviations from the label mean) depends on the mean of
        all accumulated labels.
        """
        super().__init__(requires_hooks=False)
        # Single source of truth for the empty state, so __init__ and reset()
        # can never drift apart.
        self.reset()

    def reset(self):
        """Reset metric state."""
        self.x_sum_squares = 0.0
        self.y_sum_squares = 0.0
        # None (not empty tensors): the first batch after a reset then starts a
        # fresh accumulation that keeps the labels' own dtype and device,
        # exactly as it does after __init__.
        self.x_labels = None
        self.y_labels = None

    @check_shapes
    def __call__(self, model, preds, data):
        """
        Accumulate one batch and return the running R2 score.

        Args:
            model: A NeuroBenchModel (not used by this metric).
            preds: A tensor of model predictions; columns 0 and 1 are read.
            data: A tuple of data and labels; data[1] holds the labels and
                its columns 0 and 1 are read.
        Returns:
            float: R2 Score over all samples accumulated since the last reset.
        """
        labels = data[1]
        self.x_sum_squares += torch.sum((labels[:, 0] - preds[:, 0]) ** 2).item()
        self.y_sum_squares += torch.sum((labels[:, 1] - preds[:, 1]) ** 2).item()

        x_batch = labels[:, 0].detach()
        y_batch = labels[:, 1].detach()
        if self.x_labels is None:
            # Copy: a column slice is a view into the caller's label tensor, so
            # storing it directly would alias (and keep alive) that buffer.
            self.x_labels = x_batch.clone()
            self.y_labels = y_batch.clone()
        else:
            # torch.cat allocates a new tensor, so no extra copy is needed here.
            self.x_labels = torch.cat((self.x_labels, x_batch))
            self.y_labels = torch.cat((self.y_labels, y_batch))

        return self.compute()

    def compute(self):
        """
        Compute r2 score using accumulated data.

        For each column, R2 = 1 - SSE / SST, where SSE is the accumulated sum
        of squared prediction errors and SST is the sum of squared deviations
        of the accumulated labels from their mean. The x and y scores are
        averaged and returned as a Python float.

        Returns NaN when no data has been accumulated. A column whose labels
        are all identical has SST == 0, which yields -inf (or NaN if its SSE
        is also 0).
        """
        if self.x_labels is None or self.y_labels is None:
            return float("nan")
        x_r2 = 1 - (self.x_sum_squares / self._total_sum_squares(self.x_labels))
        y_r2 = 1 - (self.y_sum_squares / self._total_sum_squares(self.y_labels))
        r2 = (x_r2 + y_r2) / 2
        return r2.item()

    @staticmethod
    def _total_sum_squares(labels):
        """Return sum((labels - labels.mean()) ** 2) as a 0-d tensor."""
        # Tensor.var only supports floating-point inputs; promote integer
        # labels to the default float dtype instead of raising.
        if not labels.is_floating_point():
            labels = labels.to(torch.get_default_dtype())
        # Population variance (correction=0) times the count is exactly the
        # total sum of squares around the mean.
        return labels.var(correction=0) * len(labels)