from neurobench.datasets.dataset import Dataset
import numpy as np
import torch
import os
from .utils import download_url
from urllib.error import URLError
"""
Preprocessed EEG Motor Imagery (MI) dataset derived from the Lee2019 dataset
(Lee et al., 2019, "EEG dataset and OpenBMI toolbox for three BCI paradigms:
An investigation into BCI illiteracy"), adapted for the THOR challenge.
The original Lee2019 dataset contains EEG recordings from 54 subjects performing
left-hand and right-hand motor imagery tasks. This version provides preprocessed
train/validation splits ready for use in the NeuroBench THOR benchmark.
Data is automatically downloaded from:
https://huggingface.co/datasets/NeuroBench/thor_eeg_mi
Original dataset reference:
Lee, M.-H., et al. (2019). EEG dataset and OpenBMI toolbox for three BCI
paradigms: An investigation into BCI illiteracy.
GigaScience, 8(5), giz002. https://doi.org/10.1093/gigascience/giz002
"""
# Base location of the hosted files; each file name from FILES is appended to
# it (as "<BASE_URL>/<file name>") to form the download URL.
BASE_URL = "https://huggingface.co/datasets/NeuroBench/thor_eeg_mi/resolve/main"

# Maps a logical key ("<split>_<part>") to the file name used both remotely and
# on disk. Resulting keys, in order: train_X, train_y, val_X, val_y.
FILES = {
    f"{split_name}_{part}": f"{split_name}_{part}.npy"
    for split_name in ("train", "val")
    for part in ("X", "y")
}
class ThorEEGMI(Dataset):
    """
    Preprocessed Lee2019 EEG Motor Imagery dataset adapted for the THOR challenge.

    This dataset is derived from the Lee2019 benchmark (Lee et al., 2019), which
    recorded EEG from 54 subjects performing left-hand and right-hand motor imagery
    tasks using a 62-channel cap at 1000 Hz. The data here is a preprocessed and
    re-split version tailored for the NeuroBench THOR challenge, provided as
    train/validation splits of trials with shape ``(n_trials, n_timesteps, n_channels)``.

    Each sample is a single motor imagery trial and the target is an integer class
    label (0 = right hand, 1 = left hand).

    Args:
        root (str): Root directory where the dataset files are stored (or
            will be downloaded to).
        split (str): Which split to load. One of ``"train"`` or ``"val"``.
        download (bool): If ``True``, downloads the dataset files from
            HuggingFace if they are not already present in ``root``.

    Attributes:
        root (str): Directory holding the dataset files.
        split (str): Name of the loaded split.
        data (torch.Tensor): ``float32`` tensor with the trials of the split.
        targets (torch.Tensor): ``int64`` tensor with one label per trial.

    Raises:
        ValueError: If ``split`` is not ``"train"`` or ``"val"``, or if the
            loaded data and label arrays have different lengths.
        FileNotFoundError: If the files of the requested split are missing
            from ``root`` after the optional download step.
        RuntimeError: If a download fails with a ``URLError``.

    Reference:
        Lee, M.-H., et al. (2019). EEG dataset and OpenBMI toolbox for three BCI
        paradigms: An investigation into BCI illiteracy.
        GigaScience, 8(5), giz002. https://doi.org/10.1093/gigascience/giz002
    """

    def __init__(
        self,
        root: str,
        split: str = "train",
        download: bool = True,
    ):
        """
        Validate the arguments, optionally download the files and load the split.

        Args:
            root (str): Directory for the dataset files; created if missing.
            split (str): ``"train"`` or ``"val"``.
            download (bool): Whether to fetch files missing from ``root``.
        """
        super().__init__()

        if split not in ("train", "val"):
            raise ValueError(f"split must be 'train' or 'val', got '{split}'")

        self.root = root
        self.split = split

        # The directory is created regardless of ``download`` so that the
        # subsequent path checks operate on an existing location.
        os.makedirs(self.root, exist_ok=True)

        if download:
            self._download()

        self._load_data()

    def _file_path(self, key: str) -> str:
        """Return the full local path for a dataset file key."""
        return os.path.join(self.root, FILES[key])

    def _download(self):
        """
        Download all four dataset files if not already present.

        Files that already exist under ``root`` are skipped, so both splits
        are fetched once and later instantiations reuse them.

        Raises:
            RuntimeError: If fetching a file raises ``URLError``.
        """
        for key, filename in FILES.items():
            dest = self._file_path(key)
            if os.path.exists(dest):
                continue

            url = f"{BASE_URL}/{filename}"
            try:
                print(f"Downloading {url}")
                download_url(url, dest)
            except URLError as error:
                raise RuntimeError(
                    f"Failed to download {filename}:\n{error}"
                ) from error

    def _load_data(self):
        """
        Load the selected split from disk into ``self.data`` / ``self.targets``.

        Raises:
            FileNotFoundError: If the data or label file of the split is missing.
            ValueError: If the number of trials and labels differ.
        """
        x_path = self._file_path(f"{self.split}_X")
        y_path = self._file_path(f"{self.split}_y")

        for kind, path in (("Data", x_path), ("Label", y_path)):
            if not os.path.exists(path):
                raise FileNotFoundError(
                    f"{kind} file not found: {path}. "
                    "Re-initialize with download=True to fetch the dataset."
                )

        # Load raw numpy arrays
        X = np.load(x_path)  # shape: (n_trials, n_timesteps, n_channels)
        y = np.load(y_path)  # shape: (n_trials,)

        # Explicit check instead of ``assert`` so the validation still runs
        # under ``python -O``; done before assigning any state so a failed
        # load never leaves the instance half-initialized.
        if len(X) != len(y):
            raise ValueError(
                f"Mismatch between number of samples ({len(X)}) "
                f"and labels ({len(y)})"
            )

        # Convert to torch tensors
        self.data = torch.tensor(X, dtype=torch.float32)
        self.targets = torch.tensor(y, dtype=torch.long)

    def __len__(self) -> int:
        """
        Return the number of EEG trials in the split.

        Returns:
            int: number of samples in the dataset
        """
        return len(self.data)

    def __getitem__(self, idx):
        """
        Return one or several EEG trials together with their labels.

        Args:
            idx (int or list or torch.Tensor): index or indices of the
                sample(s) to retrieve.

        Returns:
            sample (torch.Tensor): EEG trial of shape
                ``(n_timesteps, n_channels)`` for a single index, or
                ``(batch, n_timesteps, n_channels)`` for a list/tensor of
                indices.
            target (torch.Tensor): class label(s), shape ``()`` for a
                single index or ``(batch,)`` for multiple indices.
        """
        # Tensor indexing handles ints, lists and index tensors alike, so no
        # type dispatch is required here.
        return self.data[idx], self.targets[idx]

    @property
    def n_timesteps(self) -> int:
        """Number of time steps per trial."""
        return self.data.shape[1]

    @property
    def n_channels(self) -> int:
        """Number of EEG channels."""
        return self.data.shape[2]

    @property
    def n_classes(self) -> int:
        """
        Number of classes, computed as the largest label plus one.

        This equals the number of distinct labels only when labels are the
        contiguous integers ``0 .. K-1``. An empty split yields ``0``.
        """
        # ``max()`` raises on an empty tensor; report zero classes instead.
        if self.targets.numel() == 0:
            return 0
        return int(self.targets.max().item()) + 1

    def __repr__(self) -> str:
        """Return a summary with the split name and the dataset dimensions."""
        return (
            f"{self.__class__.__name__}("
            f"split={self.split!r}, "
            f"n_trials={len(self)}, "
            f"n_channels={self.n_channels}, "
            f"n_timesteps={self.n_timesteps}, "
            f"n_classes={self.n_classes})"
        )