neurobench.models

Torch Models

class neurobench.models.torch_model.TorchModel(net)

Bases: NeuroBenchModel

The TorchModel class wraps an nn.Module.

__call__(batch)

Wraps the forward pass of the torch.nn model.

Parameters:

batch – A PyTorch tensor of shape (batch, timesteps, features*)

Returns:

preds – a tensor that is either compared directly with the targets or passed to NeuroBenchPostProcessors.
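As a rough mental model, the wrapper's forward pass can be pictured as the sketch below. `MinimalTorchModel` is a hypothetical stand-in written for illustration, assuming that `__call__` simply forwards the batch to the wrapped module unchanged; it is not NeuroBench's own source.

```python
import torch
import torch.nn as nn


class MinimalTorchModel:
    """Hypothetical stand-in for a TorchModel-style wrapper (illustration only)."""

    def __init__(self, net: nn.Module):
        self.net = net

    def __call__(self, batch: torch.Tensor) -> torch.Tensor:
        # Assumption: the (batch, timesteps, features*) tensor is forwarded as-is.
        return self.net(batch)


# nn.Linear acts on the last dimension, so the batch and time axes pass through.
net = nn.Linear(in_features=20, out_features=4)
model = MinimalTorchModel(net)

batch = torch.randn(8, 50, 20)  # (batch, timesteps, features)
with torch.no_grad():
    preds = model(batch)
print(preds.shape)  # torch.Size([8, 50, 4])
```

Because the output keeps the (batch, timesteps, …) layout, it can be compared with per-timestep targets or handed to a post-processor that reduces over time.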

__init__(net)

Initializes the TorchModel wrapper.

Parameters:

net – A PyTorch nn.Module.
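The only requirement stated here is that `net` is an nn.Module. In practice it must also accept the (batch, timesteps, features*) tensor passed to `__call__`. One example of a compatible module is a recurrent classifier that consumes batch-first sequences. `GRUClassifier` and its layer sizes are hypothetical, chosen only for illustration.

```python
import torch
import torch.nn as nn


class GRUClassifier(nn.Module):
    """Hypothetical example net: maps (batch, timesteps, features) to per-step logits."""

    def __init__(self, n_features: int, n_hidden: int, n_classes: int):
        super().__init__()
        # batch_first=True matches the (batch, timesteps, features) layout.
        self.gru = nn.GRU(n_features, n_hidden, batch_first=True)
        self.readout = nn.Linear(n_hidden, n_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out, _ = self.gru(x)      # out: (batch, timesteps, n_hidden)
        return self.readout(out)  # (batch, timesteps, n_classes)


net = GRUClassifier(n_features=20, n_hidden=32, n_classes=4)
net.eval()  # switch to inference mode before benchmarking
with torch.no_grad():
    logits = net(torch.randn(8, 50, 20))
print(logits.shape)  # torch.Size([8, 50, 4])
```

Under these assumptions the module would then be wrapped as `TorchModel(net)`.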

__net__()

Returns the underlying network.
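Because the returned object is an ordinary nn.Module, standard PyTorch introspection works on it. The following is a minimal sketch that assumes you want a parameter count. The helper `count_parameters` is hypothetical and is not part of NeuroBench.

```python
import torch.nn as nn


def count_parameters(net: nn.Module) -> int:
    """Total number of scalar parameters in a module (hypothetical helper)."""
    return sum(p.numel() for p in net.parameters())


net = nn.Sequential(nn.Linear(20, 32), nn.ReLU(), nn.Linear(32, 4))
# (20*32 + 32) + (32*4 + 4) = 672 + 132
print(count_parameters(net))  # 804
```

With a wrapped model this would read `count_parameters(model.__net__())`.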

SNNTorch Models

class neurobench.models.snntorch_models.SNNTorchModel(net)

Bases: NeuroBenchModel

The SNNTorchModel class wraps the forward pass of an SNNTorch network and ensures that its output spikes are in the format expected by downstream NeuroBench components.

__call__(data)

Executes the forward pass of the wrapped SNNTorch model on data that follows the NeuroBench specification, and ensures that the output spikes are compatible with downstream components.

Parameters:

data – A PyTorch tensor of shape (batch, timesteps, …)

Returns:

spikes – A PyTorch tensor of shape (batch, timesteps, …)
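This reference does not describe how the spikes are brought into this format. The sketch below shows one plausible approach. It steps the network through the input one timestep at a time, collects the output spikes, and stacks them back into a batch-first tensor. The sketch assumes, following the common snnTorch convention, that a single-step call returns a `(spikes, membrane)` tuple. To avoid an snnTorch dependency, `ToyLIF` is a hypothetical hand-written leaky integrate-and-fire layer that stands in for a trained snnTorch network. `run_spiking_net` is likewise a hypothetical helper.

```python
import torch
import torch.nn as nn


class ToyLIF(nn.Module):
    """Hypothetical single-layer LIF network; returns (spikes, membrane) per step."""

    def __init__(self, n_in: int, n_out: int, beta: float = 0.9, threshold: float = 1.0):
        super().__init__()
        self.fc = nn.Linear(n_in, n_out)
        self.beta = beta
        self.threshold = threshold
        self.mem = None  # membrane state carried across timesteps

    def reset(self) -> None:
        self.mem = None

    def forward(self, x: torch.Tensor):
        cur = self.fc(x)
        if self.mem is None:
            self.mem = torch.zeros_like(cur)
        # Leaky integration, then spike where the membrane exceeds threshold.
        self.mem = self.beta * self.mem + cur
        spk = (self.mem > self.threshold).float()
        # Reset by subtraction for the neurons that spiked.
        self.mem = self.mem - spk * self.threshold
        return spk, self.mem


def run_spiking_net(net: ToyLIF, data: torch.Tensor) -> torch.Tensor:
    """Step through data of shape (batch, timesteps, ...) and return batch-first spikes."""
    net.reset()  # clear state left over from the previous batch
    spikes = []
    for t in range(data.shape[1]):
        spk, _ = net(data[:, t, ...])
        spikes.append(spk)
    # stack -> (timesteps, batch, ...); transpose -> (batch, timesteps, ...)
    return torch.stack(spikes).transpose(0, 1)


net = ToyLIF(n_in=20, n_out=4)
with torch.no_grad():
    out = run_spiking_net(net, torch.rand(8, 50, 20))
print(out.shape)  # torch.Size([8, 50, 4])
```

Resetting the state at the start of each call keeps one batch's membrane potentials from leaking into the next. The final transpose restores the batch-first layout used elsewhere in NeuroBench.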

__init__(net)

Initializes the wrapper with a trained network.

Parameters:

net – A trained SNNTorch network.

__net__()

Returns the underlying network.