"""
This file contains code from PyTorch Vision (https://github.com/pytorch/vision) which is licensed under BSD 3-Clause License.
These snippets are the Copyright (c) of Soumith Chintala 2016. All other code is the Copyright (c) of the NeuroBench Developers 2023.
"""
from .dataset import NeuroBenchDataset
from .utils import check_integrity, download_url
from torch.utils.data import Dataset
import os
import torch
import math
import numpy as np
import h5py
from scipy.signal import convolve2d
from urllib.error import URLError
# Spacing, in seconds, between consecutive spike samples of the Primate
# Reaching recordings (4 ms). Despite its name, this constant holds a sampling
# *period*, not a frequency.
SAMPLING_RATE = 0.004
class PrimateReaching(NeuroBenchDataset):
    """
    Dataset for the Primate Reaching Task.

    The dataset can be downloaded from the following website:
    https://zenodo.org/record/583331

    For this task, the following files are selected:
    1. indy_20170131_02.mat
    2. indy_20160630_01.mat
    3. indy_20160622_01.mat
    4. loco_20170301_05.mat
    5. loco_20170215_02.mat
    6. loco_20170210_03.mat

    The description of the structure of the dataset can be found on the website
    in the section: Variable names.

    Once these .mat files are downloaded, store them in the same directory.
    """

    url = "https://zenodo.org/record/583331/files/"
    # Expected MD5 checksum of every supported recording, keyed by file name.
    md5s = {
        "indy_20170131_02.mat": "2790b1c869564afaa7772dbf9e42d784",
        "indy_20160630_01.mat": "197413a5339630ea926cbd22b8b43338",
        "indy_20160622_01.mat": "c33d5fff31320d709d23fe445561fb6e",
        "loco_20170301_05.mat": "47342da09f9c950050c9213c3df38ea3",
        "loco_20170215_02.mat": "739b70762d838f3a1f358733c426bb02",
        "loco_20170210_03.mat": "4cae63b58c4cb9c8abd44929216c703b",
    }

    def __init__(
        self,
        file_path,
        filename,
        num_steps,
        train_ratio=0.8,
        label_series=False,
        biological_delay=0,
        spike_sorting=False,
        stride=0.004,
        bin_width=0.028,
        max_segment_length=2000,
        split_num=1,
        remove_segments_inactive=False,
        download=True,
    ):
        """
        Initialises the Dataset for the Primate Reaching Task.

        Args:
            file_path (str): The path to the directory storing the matlab files.
            filename (str): The name of the file that will be loaded; ".mat" is appended
                if missing. Must contain "indy" or "loco".
            num_steps (int): Number of binning windows included per sample.
                In the real-time case, this should be 1.
            train_ratio (float): Fraction of each chunk's segments used for training; the
                remaining segments are shared between validation (rounded down) and test.
                Default is 0.8 (80% of data is training).
            label_series (bool): If True, a label is returned for every time step included
                in a sample instead of only the last one. Useful for training with multiple
                timesteps. Default is False.
            biological_delay (int): Number of time steps the labels are shifted relative to
                the spikes. Default is 0, i.e. no delay applied.
            spike_sorting (bool): If True, every sorted unit becomes its own feature (units
                without any spike are dropped); otherwise all units of a channel are merged.
                Default is False.
            stride (float): Step, in seconds, between consecutive sample indices inside a
                segment. Default is 0.004 (4ms).
            bin_width (float): Width, in seconds, of the window over which spikes are counted.
                Default is 0.028 (28ms).
            max_segment_length (int): Segments spanning this many time steps or more are
                treated as subject inactivity. Default is 2000 data points (8s).
            split_num (int): The number of consecutive chunks the segments are divided into;
                each chunk is split into train/val/test separately. Default is 1 (no splits).
            remove_segments_inactive (bool): Whether to drop segments whose length is at
                least max_segment_length. Default is False.
            download (bool): If True, downloads the file into ``file_path`` unless a copy
                with the expected checksum is already there.

        Raises:
            AssertionError: If the file is missing or a parameter is out of range.
            ValueError: If the filename refers to neither an indy nor a loco recording.
            KeyError: If ``download`` is True and the file is not one of the supported ones.
        """
        # The samples and labels of the dataset (filled in by load_data).
        self.samples = None
        self.labels = None

        # used for input data file management
        self.filename = filename if filename.endswith(".mat") else filename + ".mat"
        self.file_path = os.path.join(file_path, self.filename)

        if download:
            self.download()

        assert os.path.exists(self.file_path), f"File not found: {self.file_path}"

        # related to processing of spike data
        self.spike_sorting = spike_sorting
        self.delay = biological_delay
        self.stride = stride
        self.bin_width = bin_width
        self.num_steps = num_steps
        self.train_ratio = train_ratio
        self.label_series = label_series
        # Number of SAMPLING_RATE steps covered by one binning window.
        self.ratio = int(np.round(self.bin_width / SAMPLING_RATE))

        # test parameters
        assert self.delay >= 0, "biological_delay must be non-negative"
        assert self.stride >= SAMPLING_RATE, "stride must be at least the sampling size (0.004s)"
        assert (
            self.bin_width >= SAMPLING_RATE
        ), "The binning window has to be greater than the sampling size (i.e. 0.004s)"
        assert self.num_steps >= 1, "num_steps must be at least 1"
        assert 0 <= self.train_ratio <= 1, "train_ratio must lie within [0, 1]"

        # Defines the beginning and end of each segment.
        self.start_end_indices = None
        self.time_segments = None

        # Defines the maximum length of a segment.
        self.max_segment_length = max_segment_length
        assert self.max_segment_length >= 0, "max_segment_length must be non-negative"

        self.split_num = split_num

        # These lists store the time indices that belong to the training/validation/test set
        self.ind_train, self.ind_val, self.ind_test = [], [], []

        if "indy" in filename:
            self.input_feature_size = 96
        elif "loco" in filename:
            self.input_feature_size = 192
        else:
            raise ValueError(
                "Unexpected filename. Filename should be of either indy or loco"
            )

        self.load_data()

        if self.delay > 0:
            self.apply_delay()

        if remove_segments_inactive and self.max_segment_length > 0:
            self.valid_segments = self.remove_segments_by_length()
        else:
            self.valid_segments = np.arange(self.time_segments.shape[0])

        self.split_data()

    def __len__(self):
        """Total number of sample indices in the training, validation and test sets.

        Note that ``__getitem__`` expects a time index (e.g. an element of
        ``ind_train``), not a position in ``range(len(self))``.
        """
        return len(self.ind_train) + len(self.ind_test) + len(self.ind_val)

    def __getitem__(self, idx):
        """Getter method of the dataloader.

        Returns the ``num_steps`` binned-spike vectors at time indices
        ``idx, idx - ratio, idx - 2 * ratio, ...`` (most recent first), shaped
        ``(num_steps, features)``. With ``label_series`` the labels at the same
        time indices are returned with the same layout; otherwise only the label
        column at ``idx``.
        """
        # compute indices of congruent binning windows
        # NOTE(review): for idx < (num_steps - 1) * ratio some positions are
        # negative and presumably wrap to the end of the recording — confirm.
        mask = idx - np.arange(self.num_steps) * self.ratio
        samples = self.samples[:, mask].transpose(0, 1)
        if self.label_series:
            return samples, self.labels[:, mask].transpose(0, 1)
        return samples, self.labels[:, idx]

    def _check_exists(self, file_path, md5) -> bool:
        """Return True if ``file_path`` exists and matches ``md5``."""
        return check_integrity(file_path, md5)

    def download(self):
        """Download the Primate Reaching data if it doesn't exist already.

        Download failures (URLError) are reported but not raised; the existence
        check in ``__init__`` then reports the missing file.

        Raises:
            KeyError: If ``self.filename`` is not one of the supported recordings.
        """
        md5 = self.md5s[self.filename]
        if self._check_exists(self.file_path, md5):
            return

        # os.makedirs("") fails, so only create a directory when the path has one.
        directory = os.path.dirname(self.file_path)
        if directory:
            os.makedirs(directory, exist_ok=True)

        # download file
        url = f"{self.url}{self.filename}"
        try:
            print(f"Downloading {url}")
            download_url(url, self.file_path, md5=md5)
        except URLError as error:
            print(f"Failed to download {url}:\n{error}")
        finally:
            print()

    def load_data(self):
        """Load spikes and cursor positions from the matlab (HDF5) file.

        Populates ``start_end_indices`` and ``time_segments`` (from target
        changes), ``samples`` (spike counts summed over ``bin_width`` windows,
        features x timesteps, float) and ``labels`` (cursor velocity obtained
        as the gradient of ``cursor_pos`` along the time axis, float).
        """
        # Assume input is the original dataset, instead of the reconstructed one
        print(f"Loading {self.filename}")

        # The handle must stay open while the per-unit spike references are
        # dereferenced; the context manager closes it afterwards.
        with h5py.File(self.file_path, "r") as dataset:
            # Get the reference object's locations in the HDF5/mat file
            spikes = dataset["spikes"][()]
            cursor_pos = dataset["cursor_pos"][()]
            target_pos = dataset["target_pos"][()]
            t = np.squeeze(dataset["t"][()])

            # Bin edges on the SAMPLING_RATE grid, starting bin_width before t[0]
            # (presumably so the binning window is already full at the first timestamp).
            new_t = np.arange(t[0] - self.bin_width, t[-1], SAMPLING_RATE)

            # Define the segments' start & end indices
            self.start_end_indices = np.array(self.get_flag_index(target_pos))
            self.time_segments = np.array(
                self.split_into_segments(self.start_end_indices, target_pos.shape[1])
            )

            spike_train = np.zeros((*spikes.shape, len(new_t)), dtype=np.int8)

            # iterate over hdf5 dataframe and convert spike times to a binary spike train
            for row_idx, row in enumerate(spikes):
                for col_idx, element in enumerate(row):
                    if isinstance(element, np.ndarray):
                        spike_times = element
                    else:
                        spike_times = dataset[element][()]
                    bins, _ = np.histogram(spike_times, bins=new_t)
                    # histogram assigns spikes to the lower bound of the binning window,
                    # therefore increment by one to shift to the upper bound
                    spike_train[row_idx, col_idx, np.nonzero(bins)[0] + 1] = 1

        # Length of the time axis (one entry per bin edge in new_t).
        n_time_bins = spike_train.shape[-1]

        if self.spike_sorting:
            # Flatten the two unit/channel axes into one feature axis. The time
            # axis has len(new_t) entries, so reshape with that (the squeezed
            # 1-D `t` has no second dimension).
            spike_train = np.transpose(spike_train, (2, 1, 0)).reshape(n_time_bins, -1)
            # remove empty channels
            spike_train = spike_train[:, spike_train.any(axis=0)]
            spike_train = spike_train.transpose()
        else:
            # combine units into channels
            spike_train = np.bitwise_or.reduce(spike_train, axis=0)

        # Sum spikes over a sliding window of `ratio` steps (= bin_width seconds).
        if self.ratio != 1:
            binned_spike_train = convolve2d(
                spike_train, np.ones((1, self.ratio)), mode="valid"
            )
        else:
            binned_spike_train = spike_train

        # Dimensions: (channels x timesteps)
        self.samples = torch.from_numpy(binned_spike_train).float()
        # Dimensions: (nr_features x timesteps); convert position to velocity
        self.labels = torch.gradient(torch.from_numpy(cursor_pos).float(), dim=1)[0]

    def apply_delay(self):
        """Shift the labels by the delay to account for the biological delay between
        spikes and movement onset.

        Drops the last ``delay`` samples and the first ``delay`` labels, so the
        sample at time k is paired with the original label at time k + delay.
        """
        self.samples = self.samples[:, : -self.delay]
        self.labels = self.labels[:, self.delay :]

    def split_data(self):
        """Split segments into training/validation/test set.

        The segments are divided into ``split_num`` consecutive chunks of
        ``len(time_segments) // split_num`` segments each; leftover segments that
        do not fill a whole chunk are not used. Within every chunk the first
        ``floor(train_ratio * chunk_len)`` segments go to training, the next
        ``floor((chunk_len - train_len) / 2)`` to validation and the rest to test.
        Only segments listed in ``valid_segments`` are used. Each used segment
        contributes the time indices ``start, start + stride, ...`` below its end,
        clipped to the number of time steps actually available.
        """
        split_num = self.split_num
        total_segments = self.time_segments.shape[0]
        sub_length = total_segments // split_num  # number of segments per chunk

        # Round rather than truncate: e.g. 0.036 / 0.004 evaluates slightly below 9.
        stride = int(np.round(self.stride / SAMPLING_RATE))

        train_len = math.floor(self.train_ratio * sub_length)
        val_len = math.floor((sub_length - train_len) / 2)

        # valid_segments holds *global* segment indices; a set gives O(1) lookups.
        valid = {int(seg) for seg in self.valid_segments}

        # After apply_delay() the data is shorter than the recording the segments
        # were computed on; later indices would make __getitem__ raise IndexError.
        n_timesteps = min(self.samples.shape[1], self.labels.shape[1])

        for split_no in range(split_num):
            for i in range(sub_length):
                seg_idx = split_no * sub_length + i
                if seg_idx not in valid:
                    continue

                # Position i inside the chunk decides the destination set.
                if i < train_len:
                    target = self.ind_train
                elif i < train_len + val_len:
                    target = self.ind_val
                else:
                    target = self.ind_test

                start, end = self.time_segments[seg_idx]
                target.extend(np.arange(start, min(end, n_timesteps), stride))

    def remove_segments_by_length(self):
        """Return the indices of segments shorter than max_segment_length.

        Segments whose duration (end - start) is at least ``max_segment_length``
        are excluded.
        """
        return np.nonzero(
            self.time_segments[:, 1] - self.time_segments[:, 0]
            < self.max_segment_length
        )[0]

    @staticmethod
    def split_into_segments(indices, last_idx):
        """Combine the start and end index into a NumPy array.

        Prepends 0 and appends ``last_idx`` to ``indices`` and pairs consecutive
        values, giving an array of shape ``(len(indices) + 1, 2)`` whose rows are
        ``[start, end)``.
        """
        indices = np.insert(indices, 0, 0)
        indices = np.append(indices, [last_idx])
        start_end = np.array([indices[:-1], indices[1:]])
        return np.transpose(start_end)

    @staticmethod
    def get_flag_index(target_pos):
        """Find where each segment begins and ends.

        Returns the column indices j for which any row of ``target_pos`` differs
        between columns j and j + 1. Works for any number of rows.
        """
        # Appending the last column makes the final difference zero while keeping
        # one entry per column.
        target_diff = np.diff(target_pos, axis=1, append=target_pos[:, -1:])
        indices = np.nonzero(np.sum(np.abs(target_diff), axis=0))[0]
        return indices