Source code for neurobench.datasets.speech_commands

import torch
import os
from glob import glob
from torchaudio.datasets import SPEECHCOMMANDS
from .dataset import NeuroBenchDataset


[docs] class SpeechCommands(NeuroBenchDataset, SPEECHCOMMANDS): """ Speech commands dataset v0.02 with 35 keywords. Wraps the torchaudio SPEECHCOMMANDS dataset. """
[docs] def __init__( self, path, subset: str = None, truncate_or_pad_to_1s=True, ): """ Initializes the SpeechCommands dataset. Args: path (str): path to the root directory of the dataset subset (str, optional): one of "training", "validation", or "testing". Defaults to None. truncate_or_pad_to_1s (bool, optional): whether to truncate or pad samples to 1s. Defaults to True. """ os.makedirs(path, exist_ok=True) SPEECHCOMMANDS.__init__(self, path, download=True, subset=subset) self.truncate_or_pad_to_1s = truncate_or_pad_to_1s # convert labels to indices self.labels = sorted( glob( "*/", root_dir=os.path.join(path, "SpeechCommands", "speech_commands_v0.02"), ) ) # subtract 1 to account for _background_noise_ self.labels = {key[:-1]: idx - 1 for idx, key in enumerate(self.labels)}
[docs] def __getitem__(self, idx): """ Getter method for dataset. Args: idx (int): index of sample to return Returns: waveform (torch.Tensor): waveform of audio sample label (torch.Tensor): label index of audio sample """ ( waveform, sample_rate, label, speaker_id, utterance_num, ) = SPEECHCOMMANDS.__getitem__(self, idx) if self.truncate_or_pad_to_1s: if waveform.shape[0] > sample_rate: waveform = waveform[:sample_rate] else: temp_waveform = torch.zeros((sample_rate,)) temp_waveform[: waveform.numel()] = waveform waveform = temp_waveform waveform = waveform.unsqueeze(-1) label = self.label_to_index(label) return waveform, label
[docs] def label_to_index(self, label): """ Converts a label to an index. Args: label (str): label of audio sample Returns: torch.Tensor: index of label """ return torch.tensor(self.labels[label])
def __len__(self): """ Returns number of samples in dataset. Returns: int: number of samples in dataset """ return SPEECHCOMMANDS.__len__(self)