# Source code for neurobench.processors.postprocessors.aggregate
from neurobench.processors.abstract.postprocessor import NeuroBenchPostProcessor
[docs]
class Aggregate(NeuroBenchPostProcessor):
    """Postprocessor that collapses spikes by summing over dimension 1."""

    def __call__(self, spikes):
        """
        Sum the incoming spikes along dimension 1.

        Args:
            spikes (Tensor): A torch tensor of spikes of shape
                (batch, timestep, classes).

        Returns:
            Tensor: A torch tensor of aggregated spikes of shape
                (batch, classes).

        """
        # Dimension 1 is the timestep axis under the documented input layout,
        # so the reduction yields one total per (batch, class) pair.
        totals = spikes.sum(1)
        return totals