Source code for neurobench.processors.postprocessors.aggregate

from neurobench.processors.abstract.postprocessor import NeuroBenchPostProcessor


class Aggregate(NeuroBenchPostProcessor):
    """Post-processor that collapses spikes over the timestep axis."""

    def __call__(self, spikes):
        """
        Sum the spikes across timesteps.

        Args:
            spikes (Tensor): A torch tensor of spikes of shape (batch, timestep, classes)

        Returns:
            Tensor: A torch tensor of spikes of shape (batch, classes)
        """
        # Under the documented (batch, timestep, classes) layout the timesteps
        # live on axis 1, so reducing over it leaves one total per
        # (batch, class) pair. The axis is passed positionally to match the
        # original call form.
        timestep_axis = 1
        aggregated = spikes.sum(timestep_axis)
        return aggregated