Source code for neurobench.models.snntorch_models

import torch
import snntorch as snn
from snntorch import utils
from .neurobench_model import NeuroBenchModel


class SNNTorchModel(NeuroBenchModel):
    """
    NeuroBench wrapper around a network built with the SNNTorch framework.

    Calling the wrapper runs the network's forward pass and returns the
    output spikes with batch as the first dimension and timesteps as the
    second, which is the layout downstream NeuroBench components expect.
    """

    def __init__(self, net, custom_forward=False):
        """
        Wrap a trained network.

        Args:
            net: A trained SNNTorch network. It is switched to eval mode here.
            custom_forward: If True, ``__call__`` passes the whole input
                tensor to ``net`` in one call instead of iterating over the
                timestep dimension, and only swaps the first two dimensions
                of what ``net`` returns. Defaults to False.
        """
        super().__init__()

        self.net = net
        self.net.eval()

        # Register snntorch spiking-neuron layers as activation modules.
        self.add_activation_module(snn.SpikingNeuron)

        self.custom_forward = custom_forward

    def __call__(self, data):
        """
        Run the wrapped network on NeuroBench-formatted data.

        Args:
            data: A PyTorch tensor of shape (batch, timesteps, ...).

        Returns:
            spikes: A PyTorch tensor of shape (batch, timesteps, ...).
        """
        if self.custom_forward:
            # The network consumes the full sequence itself; its output is
            # presumably time-major, so swap dims 0 and 1 -- TODO confirm
            # against the custom network's forward().
            net_output = self.net(data)
            return net_output.transpose(0, 1)

        # Prefer the network's own reset() when it defines one:
        # utils.reset(self.net) does not seem to delete all traces for the
        # synaptic neuron model.
        if not hasattr(self.net, "reset"):
            utils.reset(self.net)
        else:
            self.net.reset()

        # Data is expected to be shape (batch, timestep, features*); feed the
        # network one timestep at a time and keep only the first of the two
        # values it returns.
        num_steps = data.shape[1]
        per_step_spikes = []
        for t in range(num_steps):
            frame = data[:, t, ...]
            out_spikes, _ = self.net(frame)
            per_step_spikes.append(out_spikes)

        # Stacking yields (timesteps, batch, ...); swap to batch-first.
        time_major = torch.stack(per_step_spikes)
        return time_major.transpose(0, 1)

    def __net__(self):
        """Returns the underlying network."""
        return self.net