Source code for neurobench.metrics.workload.mse

import torch
from torch import Tensor
from neurobench.metrics.utils.decorators import check_shapes

from neurobench.metrics.abstract.workload_metric import WorkloadMetric


class MSE(WorkloadMetric):
    """Mean squared error of the model predictions."""

    def __init__(self):
        """Initialize the MSE metric."""
        # MSE only reads predictions and labels, so no model hooks are needed.
        super().__init__(requires_hooks=False)

    @check_shapes
    def __call__(self, model, preds: Tensor, data: tuple) -> float:
        """
        Compute mean squared error.

        Integral and boolean inputs are promoted to the default floating-point
        dtype before the difference is taken. Without this, ``torch.mean``
        rejects integral tensors, unsigned subtraction wraps around, and bool
        subtraction is unsupported. Floating-point (and complex) inputs keep
        the dtype that ordinary PyTorch type promotion would give them.

        Args:
            model: A NeuroBenchModel (not used by this metric).
            preds: A tensor of model predictions.
            data: A tuple of data and labels; the labels are ``data[1]``.

        Returns:
            float: Mean squared error.
        """
        labels = data[1]

        # Use the same common dtype that `preds - labels` would produce, but
        # fall back to a floating dtype when both operands are integral/bool.
        dtype = torch.result_type(preds, labels)
        if not (dtype.is_floating_point or dtype.is_complex):
            dtype = torch.get_default_dtype()

        # Cast before subtracting so unsigned integer inputs cannot wrap around.
        diff = preds.to(dtype) - labels.to(dtype)
        return torch.mean(diff ** 2).item()