Source code for neurobench.metrics.workload.membrane_updates

import torch
from neurobench.metrics.abstract.workload_metric import AccumulatedMetric
from collections import defaultdict


class MembraneUpdates(AccumulatedMetric):
    """
    Membrane potential updates metric.

    This metric computes the number of membrane potential updates occurring
    during the forward pass of the model. Update counts are accumulated per
    hooked layer type (keyed by the layer's class name, so layers of the same
    class share one bucket) and reported normalized by the number of samples.
    """

    def __init__(self):
        """Initialize the MembraneUpdates metric."""
        super().__init__(requires_hooks=True)
        # Number of samples seen since the last reset; used to normalize.
        self.total_samples = 0
        # Accumulated update counts, keyed by the hooked layer's class name.
        self.neuron_membrane_updates = defaultdict(int)

    def reset(self):
        """Reset the metric state for a new evaluation."""
        self.total_samples = 0
        self.neuron_membrane_updates = defaultdict(int)

    def __call__(self, model, preds, data):
        """
        Accumulate the number of membrane updates for one model forward pass.

        For every activation hook, every element of the first recorded
        post-fire membrane potential counts as updated. For each later step,
        the elements whose pre-fire and post-fire potentials differ count as
        updated.

        Args:
            model: A NeuroBenchModel exposing ``activation_hooks``.
            preds: A tensor of model predictions (unused by this metric).
            data: A tuple of data and labels; ``data[0].size(0)`` is taken as
                the number of samples in the batch (assumes batch-first
                data -- confirm against the caller).

        Returns:
            float: Membrane potential updates accumulated so far, normalized
            by the number of samples seen (the value of ``compute()``).
        """
        for hook in model.activation_hooks:
            layer_type = hook.layer.__class__.__name__
            pre_fire = hook.pre_fire_mem_potential
            post_fire = hook.post_fire_mem_potential

            # Compare pre/post potentials pairwise, one step at a time, rather
            # than stacking the whole history. This never materializes copies
            # of every recorded step, tolerates steps of differing shapes, and
            # cannot silently broadcast when the two lists differ in length:
            # zip stops at the shorter list. Per-step counts stay as tensors
            # and are converted to a Python int once per hook.
            changed = sum(
                torch.count_nonzero(pre - post)
                for pre, post in zip(pre_fire[1:], post_fire[1:])
            )
            updates = int(changed)

            # The first recorded step counts as an update of every element.
            if post_fire:
                updates += post_fire[0].numel()

            self.neuron_membrane_updates[layer_type] += updates

        self.total_samples += data[0].size(0)
        return self.compute()

    def compute(self):
        """
        Compute the total membrane updates normalized by the number of samples.

        Returns:
            float: Total updates to membrane potentials, summed over all
            tracked layer types and divided by the number of samples
            processed. Returns ``0.0`` when no samples have been seen.
        """
        if self.total_samples == 0:
            return 0.0
        total_mem_updates = sum(self.neuron_membrane_updates.values())
        return total_mem_updates / self.total_samples