Source code for neurobench.models.torch_model

import torch
from .model import NeuroBenchModel


class TorchModel(NeuroBenchModel):
    """NeuroBenchModel adapter that holds a PyTorch ``nn.Module`` and forwards calls to it."""

    def __init__(self, net):
        """
        Build the wrapper around an existing network.

        The network is handed to the parent initializer, kept on ``self.net``,
        and then has its ``eval()`` method invoked.

        Args:
            net: A PyTorch nn.Module.
        """
        super().__init__(net)
        self.net = net
        wrapped = self.net
        wrapped.eval()

    def __call__(self, batch):
        """
        Run the wrapped network on one batch and hand back its output.

        Args:
            batch: A PyTorch tensor of shape (batch, timesteps, features*)

        Returns:
            preds: whatever the network produces for ``batch`` -- either a
                tensor to be compared with targets or passed on to
                NeuroBenchPostProcessors.
        """
        preds = self.net(batch)
        return preds

    def __net__(self):
        """Give access to the wrapped network object stored on ``self.net``."""
        wrapped = self.net
        return wrapped